| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file depthwise_convolution_tf.cuh |
| * \brief Depthwise convolution CUDA kernels. The main logic comes |
| * from TensorFlow, but the filter layout and many argument names |
| * differ from the original version. |
| * \author shuqian.qu@hobot.cc |
| */ |
| #ifndef MXNET_OPERATOR_NN_DEPTHWISE_CONVOLUTION_TF_CUH_ |
| #define MXNET_OPERATOR_NN_DEPTHWISE_CONVOLUTION_TF_CUH_ |
| #include "../../common/cuda/utils.h" |
| #include "../mxnet_op.h" |
| |
| namespace tf { |
| namespace depthwise_conv { |
| |
| #define FULL_WARP_MASK 0xFFFFFFFF |
| #if CUDA_VERSION < 9000 |
| template<typename DType> |
| __forceinline__ __device__ DType __shfl_xor_sync(unsigned, DType val, int delta) { |
| return __shfl_xor(val, delta); |
| } |
| |
| template<typename DType> |
| __forceinline__ __device__ DType __shfl_down_sync(unsigned, DType val, int delta) { |
| return __shfl_down(val, delta); |
| } |
| |
| // shuffle masks not used before CUDA 9. |
| #define CREATE_SHFL_MASK(mask, predicate) \ |
| unsigned mask = 0u; |
| #else |
| #define CREATE_SHFL_MASK(mask, predicate) \ |
| unsigned mask = __ballot_sync(FULL_WARP_MASK, (predicate)) |
| #endif |
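| // For example, CREATE_SHFL_MASK(active_threads, pred) expands to |
| // __ballot_sync(FULL_WARP_MASK, pred) on CUDA 9.0+, so the resulting mask |
| // names exactly the lanes whose predicate is true; on older toolkits the |
| // shims above ignore the mask argument and the plain __shfl_* intrinsics |
| // are used instead. |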
| |
| struct DepthwiseArgs { |
| // Input layer dimensions |
| int batch; |
| int in_height; |
| int in_width; |
| int in_channel; |
| int filter_height; |
| int filter_width; |
| int stride_height; |
| int stride_width; |
| int pad_height; |
| int pad_width; |
| |
| // Output layer dimensions |
| int out_height; |
| int out_width; |
| int out_channel; |
| }; |
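| // Example values: a 3x3 depthwise filter with stride 1 and pad 1 on a |
| // 32x32 input gives out_height = (32 + 2 * 1 - 3) / 1 + 1 = 32 (and the |
| // same for out_width), and out_channel == in_channel, since these kernels |
| // assume a depth multiplier of 1 (each input channel is convolved with its |
| // own filter). |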
| |
| namespace cuda { |
| template<typename DType, int kFilterHeight, int kFilterWidth> |
| __global__ void __launch_bounds__(1024, 2) |
| DepthwiseConv2dForwardKernel(const DType* input, |
| const DType* filter, |
| const DepthwiseArgs args, |
| int num_outputs, |
| DType* output) { |
| const int in_channel = args.in_channel; |
| const int in_height = args.in_height; |
| const int in_width = args.in_width; |
| const int filter_height = kFilterHeight > 0 ? kFilterHeight : args.filter_height; |
| const int filter_width = kFilterWidth > 0 ? kFilterWidth : args.filter_width; |
| const int stride_height = args.stride_height; |
| const int stride_width = args.stride_width; |
| const int pad_height = args.pad_height; |
| const int pad_width = args.pad_width; |
| const int out_channel = args.out_channel; |
| const int out_height = args.out_height; |
| const int out_width = args.out_width; |
| |
| CUDA_KERNEL_LOOP(thread_id, num_outputs) { |
| // Compute the indexes of this thread in the output. |
| // |
| // We want coalesced reads so we make sure that each warp reads |
| // a contiguous chunk of memory. |
| // |
| // THIS IS PROBABLY WRONG: we are not doing coalesced reads |
| // from the input because of the depth multiplier division... |
| const int out_w = thread_id % out_width; |
| const int out_h = (thread_id / out_width) % out_height; |
| const int out_c = (thread_id / out_width / out_height) % out_channel; |
| const int out_b = thread_id / out_width / out_height / out_channel; |
| const int in_c = out_c; |
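| // For example, with out_width = 8, out_height = 8 and out_channel = 4, |
| // thread_id = 300 decomposes to out_w = 4, out_h = 5, out_c = 0 and |
| // out_b = 1, i.e. the NCHW output position (1, 0, 5, 4). |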
| |
| // Data is stored in the following format (let's assume we |
| // flatten the height and width into one contiguous dimension |
| // called "P"): |
| // |
| // B1C1P1 B1C1P2 ..... B1C2P1 B1C2P2 .... |
| // B2C1P1 B2C1P2 ..... B2C2P1 B2C2P2 .... |
| // |
| // Each row contains the in_channel * in_height * in_width values |
| // for one sample in the batch. |
| // |
| // We can further flatten it into: |
| // |
| // B1C1P1 B1C1P2 ..... |
| // B1C2P1 B1C2P2 .... |
| // B2C1P1 B2C1P2 ..... |
| // B2C2P1 B2C2P2 .... |
| // |
| // where each row is a contiguous array of all of the spatial |
| // pixels for a given batch and input depth. The following |
| // loop unrolls across the filter dimensions for a given thread, |
| // indexing into the filter value and the corresponding input |
| // patch. |
| // |
| // We can compute the index into the patch once right here. |
| const int input_offset_temp = (out_b * in_channel + in_c) * (in_height * in_width); |
| const int filter_offset_temp = in_c * filter_height * filter_width; |
| |
| // Finally, we can iterate over the spatial dimensions and perform the |
| // convolution, writing into the output at the end. |
| // |
| // We perform an additional optimization, where we can determine |
| // whether the patch fits within the image indices statically, and |
| // avoid boundary checking within the loop. |
| const int input_h_start = out_h * stride_height - pad_height; |
| const int input_w_start = out_w * stride_width - pad_width; |
| const int input_h_end = input_h_start + filter_height; |
| const int input_w_end = input_w_start + filter_width; |
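| // For example, with stride 1 and pad 1 the output column out_w == 0 gives |
| // input_w_start == -1, so it falls through to the boundary-checked loop |
| // below; only patches that the test below proves to lie fully inside the |
| // image take the unchecked loop. |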
| |
| DType sum = 0; |
| if (input_h_start >= 0 && input_w_start >= 0 && |
| input_h_end < in_height && input_w_end < in_width) { |
| // Loop that doesn't need to check for boundary conditions. |
| CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) { |
| const int in_h = input_h_start + f_h; |
| const int filter_offset_h = filter_width * f_h; |
| CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) { |
| const int in_w = input_w_start + f_w; |
| const int input_offset = (input_offset_temp) + (in_h * in_width) + in_w; |
| const int filter_offset = filter_offset_temp + filter_offset_h + f_w; |
| sum += ldg(input + input_offset) * ldg(filter + filter_offset); |
| } |
| } |
| } else { |
| // Loop that needs to check for boundary conditions. |
| CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) { |
| const int in_h = input_h_start + f_h; |
| const int filter_offset_h = filter_width * f_h; |
| CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) { |
| const int in_w = input_w_start + f_w; |
| // TODO(vrv): the in_h check can be done outside of this loop; |
| // benchmark both methods to determine the better decision. |
| if (in_h >= 0 && in_h < in_height && in_w >= 0 && in_w < in_width) { |
| const int input_offset = input_offset_temp + (in_h * in_width) + in_w; |
| const int filter_offset = filter_offset_temp + filter_offset_h + f_w; |
| sum += ldg(input + input_offset) * ldg(filter + filter_offset); |
| } |
| } |
| } |
| } |
| output[thread_id] = sum; |
| } |
| } |
| |
| // The DepthwiseConv2dKernelSmall performs either the forward or the backward |
| // input convolution, depending on a template argument of this enum. |
| enum DepthwiseConv2dDirection { DIRECTION_FORWARD, DIRECTION_BACKWARD }; |
| |
| // CUDA kernel to compute the depthwise convolution forward pass in NCHW format, |
| // tailored for small images up to 32x32. Only use this kernel if |
| // CanLaunchDepthwiseConv2dGPUSmall(args) returns true. |
| // Tiles of the input and filter tensors are loaded into shared memory before |
| // performing the convolution. Each thread handles two elements per iteration, |
| // one each in the lower and upper half of a tile. |
| // Backward input direction is the same as forward direction with the filter |
| // rotated by 180°. |
| template <typename DType, DepthwiseConv2dDirection kDirection, |
| int kBlockSlices, bool kEvenHeight, int kFilterHeight, int kFilterWidth> |
| __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dKernelSmall( |
| const DepthwiseArgs args, const DType* input, const DType* filter, DType* output) { |
| extern __shared__ __align__(sizeof(DType)) unsigned char shared_memory[]; |
| DType* const shared_data = reinterpret_cast<DType*>(shared_memory); |
| |
| const int in_height = args.in_height; |
| const int in_width = args.in_width; |
| const int in_channel = args.in_channel; |
| const int filter_height = kFilterHeight > 0 ? kFilterHeight : args.filter_height; |
| const int filter_width = kFilterWidth > 0 ? kFilterWidth : args.filter_width; |
| const int pad_height = args.pad_height; |
| const int pad_width = args.pad_width; |
| |
| // Fixed blockDim.z, tailored for maximum grid size for images of size 16x16. |
| const int block_height = blockDim.y; |
| |
| // These values are the same for all threads and could |
| // be precomputed on the CPU. |
| const int block_pixels = in_width * block_height; |
| const int block_size = block_pixels * kBlockSlices; |
| const int in_pixels = in_width * in_height; |
| const int in_increment = in_width - 1; |
| const int filter_pixels = filter_height * filter_width; |
| const int tile_width = in_width + filter_width - 1; |
| const int even_height = kEvenHeight || (1 & ~in_height); |
| const int tile_height = in_height + filter_height - even_height; |
| const int tile_pixels = tile_width * tile_height; |
| const int tile_size = tile_pixels * kBlockSlices; |
| const int tile_offset = block_height * tile_width; |
| const int pad_offset = pad_height * tile_width + pad_width; |
| const int in_slices = in_channel * args.batch; |
| const int in_blocks = (in_slices + kBlockSlices - 1) / kBlockSlices; |
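| // For example, with in_width = in_height = 16, a 3x3 filter, pad 1, |
| // kBlockSlices = 8 and block_height = 8 (as chosen by the launcher below): |
| // block_pixels = 128, block_size = 1024, tile_width = tile_height = 18, |
| // tile_pixels = 324, tile_size = 2592, tile_offset = 144 and |
| // pad_offset = 19. Each thread handles one pixel in the lower half of the |
| // tile and a second one tile_offset elements away in the upper half. |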
| |
| const int thread_width = threadIdx.x; |
| const int thread_height = threadIdx.y; |
| const int thread_channel = threadIdx.z; |
| |
| // Position in block. |
| const int thread_pix = thread_height * in_width + thread_width; |
| const int thread_idx = thread_channel * block_pixels + thread_pix; |
| |
| // Initialize tile, in particular the padding. |
| for (int i = thread_idx; i < tile_size; i += block_size) { |
| shared_data[i] = DType(0); |
| } |
| __syncthreads(); |
| |
| // Position in tensors. |
| const int tensor_idx = thread_channel * in_pixels + thread_pix; |
| |
| // Position in (padded) shared memory. |
| const int data_pix = thread_height * tile_width + thread_width; |
| const int data_idx = thread_channel * tile_pixels + data_pix; |
| |
| // Position in shared memory, offset by pad_height / pad_width. |
| const int tile_idx = data_idx + pad_offset; |
| |
| const int filter_pix = thread_pix; |
| const int filter_channel = thread_channel; |
| const int filter_idx = filter_pixels * filter_channel + filter_pix; |
| |
| const int max_slice = in_slices - thread_channel; |
| const int filter_write_offset = filter_pix < filter_pixels ? tile_size + filter_idx : 0; |
| const int filter_read_offset = tile_size + |
| (kDirection == DIRECTION_FORWARD ? |
| filter_pixels * filter_channel : filter_pixels * (filter_channel + 1)); |
| const bool skip_second = !kEvenHeight && thread_height + (in_height & 1) == block_height; |
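| // In the forward direction filter_read_offset points at the first tap of |
| // this slice's filter and the inner loop below advances through the taps |
| // in order; in the backward direction it points one element past the last |
| // tap and the loop decrements before each read, which applies the filter |
| // rotated by 180 degrees. |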
| |
| for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) { |
| const int slice = b * kBlockSlices; |
| |
| const int inout_offset = slice * in_pixels + tensor_idx; |
| const bool slice_in_range = slice < max_slice; |
| |
| if (slice_in_range) { |
| const DType* const in_ptr = inout_offset + input; |
| DType* const tile_ptr = tile_idx + shared_data; |
| tile_ptr[0] = ldg(in_ptr); |
| if (!skip_second) { |
| tile_ptr[tile_offset] = ldg(block_pixels + in_ptr); |
| } |
| } |
| |
| if (filter_write_offset != 0) { |
| const int filter_offset = ((slice + filter_channel) % in_channel) * filter_pixels + filter_pix; |
| shared_data[filter_write_offset] = ldg(filter_offset + filter); |
| } |
| |
| // Note: the condition to reach this is uniform across the entire block. |
| __syncthreads(); |
| |
| if (slice_in_range) { |
| DType sum1 = 0; |
| DType sum2 = 0; |
| int shared_offset = data_idx; |
| const DType* filter_ptr = filter_read_offset + shared_data; |
| CUDA_UNROLL for (int r = 0; r < filter_height; ++r) { |
| CUDA_UNROLL for (int c = 0; c < filter_width; ++c) { |
| if (kDirection == DIRECTION_BACKWARD) { |
| filter_ptr--; |
| } |
| const DType filter_value = *filter_ptr; |
| const DType* const tile_ptr = shared_offset + shared_data; |
| sum1 += filter_value * tile_ptr[0]; |
| sum2 += filter_value * tile_ptr[tile_offset]; |
| ++shared_offset; |
| if (kDirection == DIRECTION_FORWARD) { |
| filter_ptr++; |
| } |
| } |
| shared_offset += in_increment; |
| } |
| DType* const out_ptr = inout_offset + output; |
| if (kDirection == DIRECTION_FORWARD) { |
| out_ptr[0] = sum1; |
| if (!skip_second) { |
| out_ptr[block_pixels] = sum2; |
| } |
| } else { |
| out_ptr[0] += sum1; |
| if (!skip_second) { |
| out_ptr[block_pixels] += sum2; |
| } |
| } |
| } |
| |
| // Note: the condition to reach this is uniform across the entire block. |
| __syncthreads(); |
| } |
| } |
| |
| template<typename DType> |
| __global__ void __launch_bounds__(640, 2) |
| DepthwiseConv2dBackwardDataKernel(const DepthwiseArgs args, |
| const DType* out_grad, |
| const DType* filter, DType* in_grad, |
| int num_in_grad) { |
| const int channel = args.in_channel; |
| const int in_height = args.in_height; |
| const int in_width = args.in_width; |
| const int filter_height = args.filter_height; |
| const int filter_width = args.filter_width; |
| const int stride_height = args.stride_height; |
| const int stride_width = args.stride_width; |
| const int pad_height = args.pad_height; |
| const int pad_width = args.pad_width; |
| const int out_height = args.out_height; |
| const int out_width = args.out_width; |
| |
| const int in_pixels = in_height * in_width; |
| const int out_pixels = out_height * out_width; |
| |
| CUDA_KERNEL_LOOP(thread_id, num_in_grad) { |
| // Compute the indexes of this thread in the input. |
| const int in_w = thread_id % in_width; |
| const int in_h = (thread_id / in_width) % in_height; |
| const int channel_idx = (thread_id / in_width / in_height) % channel; |
| const int batch_idx = thread_id / channel / in_width / in_height; |
| DType sum = 0.0f; |
| |
| const int out_h_start = mxnet::common::cuda::CudaMax<int>( |
| 0, (in_h - filter_height + pad_height + stride_height) / stride_height); |
| const int out_h_end = mxnet::common::cuda::CudaMin( |
| out_height - 1, (in_h + pad_height) / stride_height); |
| const int out_w_start = mxnet::common::cuda::CudaMax<int>( |
| 0, (in_w - filter_width + pad_width + stride_width) / stride_width); |
| const int out_w_end = mxnet::common::cuda::CudaMin( |
| out_width - 1, (in_w + pad_width) / stride_width); |
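| // For example, with a 3-tap filter, pad_height = 1, stride_height = 2 and |
| // an 8-row input (so out_height = 4), the input row in_h = 5 contributes |
| // to output rows out_h_start = max(0, (5 - 3 + 1 + 2) / 2) = 2 through |
| // out_h_end = min(3, (5 + 1) / 2) = 3. |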
| |
| const int filter_offset_temp = channel_idx * filter_height * filter_width; |
| const int out_grad_offset_temp = (batch_idx * channel * out_pixels) + |
| (channel_idx * out_pixels); |
| |
| for (int out_h = out_h_start; out_h <= out_h_end; ++out_h) { |
| const int f_h = in_h + pad_height - out_h * stride_height; |
| const int filter_offset_h = filter_offset_temp + f_h * filter_width; |
| const int out_grad_offset_h = out_grad_offset_temp + out_h * out_width; |
| for (int out_w = out_w_start; out_w <= out_w_end; ++out_w) { |
| const int f_w = in_w + pad_width - out_w * stride_width; |
| const int filter_offset = filter_offset_h + f_w; |
| const int out_grad_offset = out_grad_offset_h + out_w; |
| sum += ldg(out_grad + out_grad_offset) * ldg(filter + filter_offset); |
| } |
| } |
| const int in_grad_offset = (batch_idx * channel * in_pixels) + |
| (channel_idx * in_pixels) + (in_h * in_width) + (in_w); |
| in_grad[in_grad_offset] += sum; |
| } |
| } |
| |
| // A CUDA kernel to compute the depthwise convolution backprop w.r.t. the filter. |
| template <typename DType, int kFilterWidth, int kFilterHeight> |
| __global__ void __launch_bounds__(640, 2) |
| DepthwiseConv2dBackwardFilterKernel(const DepthwiseArgs args, |
| const DType* out_backprop, |
| const DType* input, |
| DType* filter_backprop, |
| int num_out_backprop) { |
| const int in_channel = args.in_channel; |
| const int in_height = args.in_height; |
| const int in_width = args.in_width; |
| const int filter_height = kFilterHeight > 0 ? kFilterHeight : args.filter_height; |
| const int filter_width = kFilterWidth > 0 ? kFilterWidth : args.filter_width; |
| const int stride_height = args.stride_height; |
| const int stride_width = args.stride_width; |
| const int pad_height = args.pad_height; |
| const int pad_width = args.pad_width; |
| const int out_channel = args.out_channel; |
| const int out_height = args.out_height; |
| const int out_width = args.out_width; |
| |
| CUDA_KERNEL_LOOP(thread_id, num_out_backprop) { |
| // Compute the indexes of this thread in the output. |
| const int out_w = thread_id % out_width; |
| const int out_h = (thread_id / out_width) % out_height; |
| const int out_c = (thread_id / out_width / out_height) % out_channel; |
| const int out_b = thread_id / out_width / out_height / out_channel; |
| const int in_c = out_c; |
| |
| // Decide whether the whole input patch is valid; if so, we can skip the |
| // boundary checks for each element. |
| const int in_row_start = out_h * stride_height - pad_height; |
| const int in_col_start = out_w * stride_width - pad_width; |
| const int in_row_end = in_row_start + filter_height; |
| const int in_col_end = in_col_start + filter_width; |
| |
| const int out_backprop_offset = |
| (out_b * out_channel * out_height * out_width) + |
| (out_c * out_height * out_width) + (out_h * out_width) + |
| (out_w); |
| |
| const DType out_bp = ldg(out_backprop + out_backprop_offset); |
| if (in_row_start >= 0 && in_col_start >= 0 && |
| in_row_end < in_height && in_col_end < in_width) { |
| CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) { |
| const int in_row = in_row_start + f_h; |
| // Avoid repeated computation. |
| const int input_offset_temp = |
| (out_b * in_channel * in_height * in_width) + |
| (in_c * in_height * in_width) + (in_row * in_width); |
| const int filter_backprop_temp = |
| (in_c * filter_width * filter_height) + |
| (filter_width * f_h); |
| |
| CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) { |
| const int in_col = in_col_start + f_w; |
| const int input_offset = input_offset_temp + in_col; |
| DType partial_sum = ldg(input + input_offset) * out_bp; |
| DType* addr = filter_backprop + (filter_backprop_temp + f_w); |
| atomicAdd(addr, partial_sum); |
| } |
| } |
| } else { |
| CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) { |
| const int in_row = in_row_start + f_h; |
| // Avoid repeated computation. |
| const int input_offset_temp = |
| (out_b * in_channel * in_height * in_width) + |
| (in_c * in_height * in_width) + (in_row * in_width); |
| const int filter_backprop_temp = |
| (in_c * filter_width * filter_height) + |
| (filter_width * f_h); |
| CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) { |
| const int in_col = in_col_start + f_w; |
| |
| if (in_row >= 0 && in_row < in_height && in_col >= 0 && in_col < in_width) { |
| const int input_offset = input_offset_temp + in_col; |
| DType partial_sum = ldg(input + input_offset) * out_bp; |
| DType* addr = filter_backprop + (filter_backprop_temp + f_w); |
| // Potentially many threads can add to the same address so we have |
| // to use atomic add here. |
| // TODO(jmchen): If atomic add turns out to be slow, we can: |
| // 1. allocate multiple buffers for the gradients (one for each |
| // example in a batch, for example). This can reduce the |
| // contention on the destination; 2. Have each thread compute one |
| // gradient for an element in the filters. This should work well |
| // when the input depth is big and filter size is not too small. |
| atomicAdd(addr, partial_sum); |
| } |
| } |
| } |
| } |
| } |
| } |
| |
| // CUDA kernel to compute the depthwise convolution backward w.r.t. filter in |
| // NCHW format, tailored for small images up to 32x32. Only use this kernel if |
| // CanLaunchDepthwiseConv2dGPUSmall(args) returns true. |
| // Tiles of the input tensor are loaded into shared memory before performing the |
| // convolution. Per iteration and filter element, each thread first performs |
| // a partial convolution for two elements, one each in the lower and upper half |
| // of a tile. The intermediate results of all pixels of a warp are then |
| // accumulated and written to shared memory. Finally, the values in shared |
| // memory are warp-accumulated (in chunks of kAccumPixels elements) and summed |
| // up in global memory using atomics. |
| // Requirements: threads per block must be a multiple of 32 and <= launch_bounds, |
| // kAccumPixels * 64 >= args.in_height * args.in_width * kBlockSlices. |
| template <typename DType, int kBlockSlices, int kAccumPixels, int kFilterHeight, int kFilterWidth> |
| __global__ |
| __launch_bounds__(1024, 2) void DepthwiseConv2dBackwardFilterKernelSmall( |
| const DepthwiseArgs args, const DType* output, const DType* input, DType* filter) { |
| extern __shared__ __align__(sizeof(DType)) unsigned char shared_memory[]; |
| DType* const shared_data = reinterpret_cast<DType*>(shared_memory); |
| |
| const int in_height = args.in_height; |
| const int in_width = blockDim.x; // slower (see b/62280718): args.in_width; |
| const int in_channel = args.in_channel; |
| const int filter_height = kFilterHeight > 0 ? kFilterHeight : args.filter_height; |
| const int filter_width = kFilterWidth > 0 ? kFilterWidth : args.filter_width; |
| const int pad_height = args.pad_height; |
| const int pad_width = args.pad_width; |
| |
| const int block_height = blockDim.y; |
| |
| // These values are the same for all threads and could |
| // be precomputed on the CPU. |
| const int block_pixels = in_width * block_height; |
| const int block_size = block_pixels * kBlockSlices; |
| assert((block_size & 31) == 0); |
| const int in_pixels = in_width * in_height; |
| const int in_increment = in_width - 1; |
| const int filter_pixels = filter_height * filter_width; |
| const int tile_width = in_width + filter_width - 1; |
| const int tile_height = 2 * block_height + filter_height - 1; |
| const int tile_pixels = tile_width * tile_height; |
| const int tile_size = tile_pixels * kBlockSlices; |
| const int tile_offset = block_height * tile_width; |
| const int pad_offset = pad_height * tile_width + pad_width; |
| const int in_slices = in_channel * args.batch; |
| const int in_blocks = (in_slices + kBlockSlices - 1) / kBlockSlices; |
| // The accumulator has a fixed number of pixels that can be reduced by one |
| // warp. Pixels beyond ceil(in_pixels * kBlockSlices / 64) are never written. |
| assert(kAccumPixels * 64 >= in_height * in_width * kBlockSlices); |
| const int accum_increment = kAccumPixels * kBlockSlices; |
| const int accum_size = filter_pixels * accum_increment; |
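| // For example, with in_width = in_height = 16, a 3x3 filter, |
| // kBlockSlices = 8 and kAccumPixels = 32 (as chosen by the launcher below), |
| // the assertion above is tight (32 * 64 == 16 * 16 * 8), |
| // accum_increment = 256 and accum_size = 9 * 256 = 2304 elements, placed |
| // right after the input tile in shared memory. |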
| |
| const int thread_width = threadIdx.x; |
| const int thread_height = threadIdx.y; |
| const int thread_channel = threadIdx.z; |
| |
| // Position in block. |
| const int thread_pix = thread_height * in_width + thread_width; |
| const int thread_idx = thread_channel * block_pixels + thread_pix; |
| |
| // Initialize tile, in particular the padding and accumulator. |
| for (int i = thread_idx; i < tile_size + accum_size; i += block_size) { |
| shared_data[i] = DType(0); |
| } |
| __syncthreads(); |
| |
| // Position in tensors. |
| const int tensor_idx = thread_channel * in_pixels + thread_pix; |
| |
| // Position in (padded) shared memory. |
| const int data_pix = thread_height * tile_width + thread_width; |
| const int data_idx = thread_channel * tile_pixels + data_pix; |
| |
| // Position in shared memory, offset by pad_height / pad_width. |
| const int tile_idx = data_idx + pad_offset; |
| |
| // Position in accumulator (kBlockSlices per warp, depth major). |
| const int accum_pix = thread_pix / (32 / kBlockSlices); |
| const int accum_idx = thread_channel * kAccumPixels + accum_pix; |
| |
| const int max_slice = in_slices - thread_channel; |
| const int accum_offset = tile_size + accum_idx; |
| const bool skip_second = block_height + thread_height >= in_height; |
| |
| for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) { |
| const int slice = b * kBlockSlices; |
| |
| const int inout_offset = slice * in_pixels + tensor_idx; |
| const bool slice_in_range = slice < max_slice; |
| |
| if (slice_in_range) { |
| const DType* const in_ptr = inout_offset + input; |
| DType* const tile_ptr = tile_idx + shared_data; |
| tile_ptr[0] = ldg(in_ptr); |
| if (!skip_second) { |
| tile_ptr[tile_offset] = ldg(block_pixels + in_ptr); |
| } |
| } |
| |
| // Note: the condition to reach this is uniform across the entire block. |
| __syncthreads(); |
| |
| // Not all threads of a warp may reach the __shfl_down_sync instruction, |
| // so we cannot use FULL_WARP_MASK there. |
| CREATE_SHFL_MASK(active_threads, slice_in_range); |
| |
| if (slice_in_range) { |
| const DType* const out_ptr = inout_offset + output; |
| const DType out1 = ldg(out_ptr); |
| const DType out2 = skip_second ? DType(0) : ldg(block_pixels + out_ptr); |
| int shared_offset = data_idx; |
| DType* accum_ptr = accum_offset + shared_data; |
| CUDA_UNROLL for (int r = 0; r < filter_height; ++r) { |
| CUDA_UNROLL for (int c = 0; c < filter_width; ++c) { |
| const DType* const tile_ptr = shared_offset + shared_data; |
| DType val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset]; |
| // Warp-accumulate pixels of the same depth and write to accumulator. |
| for (int delta = 16 / kBlockSlices; delta > 0; delta /= 2) { |
| val += __shfl_down_sync(active_threads, val, delta); |
| } |
| if (!(thread_idx & (32 / kBlockSlices - 1))) { |
| *accum_ptr = val; |
| } |
| ++shared_offset; |
| accum_ptr += accum_increment; |
| } |
| shared_offset += in_increment; |
| } |
| } |
| |
| // Note: the condition to reach this is uniform across the entire block. |
| __syncthreads(); |
| |
| const DType* const accum_data = tile_size + shared_data; |
| for (int i = thread_idx; i < accum_size; i += block_size) { |
| const int filter_idx = i / kAccumPixels; |
| const int filter_pix = filter_idx / kBlockSlices; |
| const int filter_channel = (slice + filter_idx % kBlockSlices) % in_channel; |
| // Convert to the CHW filter layout used throughout this file (the row |
| // stride within one filter is filter_width, matching the other kernels). |
| const int filter_offset = filter_channel * filter_pixels + |
| (filter_pix / filter_width) * filter_width + filter_pix % filter_width; |
| |
| if (filter_channel < in_channel) { |
| DType val = accum_data[i]; |
| // Warp-accumulate pixels of the same depth from the accumulator. |
| int lane_id; |
| asm volatile ("mov.u32 %0, %laneid;" : "=r"(lane_id)); |
| int sub_warp = lane_id / kAccumPixels; |
| int zeros = sub_warp * kAccumPixels; |
| unsigned mask = (kAccumPixels == 32) ? FULL_WARP_MASK : (((1U << kAccumPixels) - 1) << zeros); |
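| // For example, with kAccumPixels = 8 a lane_id of 13 gives sub_warp = 1, |
| // zeros = 8 and mask = 0x0000FF00, so the butterfly reduction below only |
| // involves lanes 8-15 of the warp. |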
| for (int delta = kAccumPixels / 2; delta > 0; delta /= 2) { |
| val += __shfl_xor_sync(mask, val, delta); |
| } |
| if (!(thread_idx & (kAccumPixels - 1))) { |
| atomicAdd(filter_offset + filter, val); |
| } |
| } |
| } |
| } |
| } |
| |
| |
| } // namespace cuda |
| |
| // Returns whether the depthwise convolution forward or backward input pass can |
| // be performed using the faster ('Small') variant of the kernel. |
| bool CanLaunchDepthwiseConv2dGPUSmall(const DepthwiseArgs& args) { |
| return args.stride_height == 1 && args.stride_width == 1 && args.in_height <= 32 && |
| args.in_width <= 32 && args.in_height == args.out_height && |
| args.in_width == args.out_width && args.pad_height >= 0 && |
| args.pad_height < args.filter_height && args.pad_width >= 0 && |
| args.pad_width < args.filter_width && |
| args.filter_height * args.filter_width <= (args.in_height + 1) / 2 * args.in_width; |
| } |
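| // For example, a 3x3 filter with stride 1 and pad 1 on a 32x32 input |
| // qualifies (the output is also 32x32 and 3 * 3 <= (32 + 1) / 2 * 32), |
| // whereas any strided convolution or input larger than 32x32 does not. |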
| |
| // Returns whether the depthwise convolution backward filter pass can be |
| // performed using the faster ('Small') variant of the kernel. |
| bool CanLaunchDepthwiseConv2dBackwardFilterGPUSmall(const DepthwiseArgs args, |
| const int block_height) { |
| return args.stride_height == 1 && args.stride_width == 1 && args.in_height <= 32 && |
| args.in_width <= 32 && args.in_height == args.out_height && |
| args.in_width == args.out_width && args.pad_height >= 0 && |
| args.pad_height < args.filter_height && args.pad_width >= 0 && |
| args.pad_width < args.filter_width && block_height <= args.in_height && |
| args.filter_height * args.filter_width <= block_height * args.in_width; |
| } |
| |
| template <typename DType, cuda::DepthwiseConv2dDirection kDirection, |
| int kBlockSlices, bool kEvenHeight> |
| void LaunchDepthwiseConv2dGPUSmall(mshadow::Stream<mxnet::gpu> *stream, |
| const DepthwiseArgs args, |
| const DType* input, const DType* filter, DType* output) { |
| const int block_height = (args.in_height + 1) / 2; |
| dim3 block_dim = dim3(args.in_width, block_height, kBlockSlices); |
| |
| const int tile_width = args.in_width + args.filter_width - 1; |
| const int tile_height = block_height * 2 + args.filter_height - 1; |
| const int tile_pixels = tile_height * tile_width; |
| const int filter_pixels = args.filter_height * args.filter_width; |
| const int shared_memory_size = |
| kBlockSlices * (tile_pixels + filter_pixels) * sizeof(DType); |
| const int num_outputs = |
| args.batch * args.out_height * args.out_width * args.out_channel; |
| int block_count = std::min(num_outputs/(block_dim.x * block_dim.y * block_dim.z) + 1, |
| (unsigned)mshadow::cuda::kMaxGridNum); |
| auto s = mshadow::Stream<mxnet::gpu>::GetStream(stream); |
| if (args.filter_height == 3 && args.filter_width == 3) { |
| cuda::DepthwiseConv2dKernelSmall<DType, kDirection, kBlockSlices, kEvenHeight, 3, 3> |
| <<<block_count, block_dim, shared_memory_size, s>>>(args, input, filter, output); |
| } else { |
| cuda::DepthwiseConv2dKernelSmall<DType, kDirection, kBlockSlices, kEvenHeight, -1, -1> |
| <<<block_count, block_dim, shared_memory_size, s>>>(args, input, filter, output); |
| } |
| MSHADOW_CUDA_POST_KERNEL_CHECK(DepthwiseConv2dKernelSmall); |
| } |
| |
| template <typename DType, cuda::DepthwiseConv2dDirection kDirection, int kBlockSlices> |
| void LaunchDepthwiseConv2dGPUSmall(mshadow::Stream<mxnet::gpu> *stream, |
| const DepthwiseArgs args, |
| const DType* input, const DType* filter, DType* output) { |
| if (args.in_height & 1) { |
| LaunchDepthwiseConv2dGPUSmall<DType, kDirection, kBlockSlices, false>( |
| stream, args, input, filter, output); |
| } else { |
| LaunchDepthwiseConv2dGPUSmall<DType, kDirection, kBlockSlices, true>( |
| stream, args, input, filter, output); |
| } |
| } |
| |
| template <typename DType, cuda::DepthwiseConv2dDirection kDirection> |
| void LaunchDepthwiseConv2dGPUSmall(mshadow::Stream<mxnet::gpu> *stream, |
| const DepthwiseArgs args, |
| const DType* input, const DType* filter, DType* output) { |
| // Maximize (power of two) kBlockSlices while keeping a block within 1024 |
| // threads (2 pixels per thread). |
| const int block_pixels = (args.in_height + 1) / 2 * args.in_width; |
| if (block_pixels > 256) { |
| LaunchDepthwiseConv2dGPUSmall<DType, kDirection, 2>(stream, args, input, filter, output); |
| } else if (block_pixels > 128) { |
| LaunchDepthwiseConv2dGPUSmall<DType, kDirection, 4>(stream, args, input, filter, output); |
| } else { |
| LaunchDepthwiseConv2dGPUSmall<DType, kDirection, 8>(stream, args, input, filter, output); |
| } |
| } |
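| // A minimal usage sketch (in_data, weight, out_data and stream here are |
| // hypothetical device pointers / stream handles supplied by the caller): |
| // |
| //   if (CanLaunchDepthwiseConv2dGPUSmall(args)) { |
| //     LaunchDepthwiseConv2dGPUSmall<float, cuda::DIRECTION_FORWARD>( |
| //         stream, args, in_data, weight, out_data); |
| //   } |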
| |
| template <typename DType, int kBlockSlices, int kAccumPixels> |
| bool TryLaunchDepthwiseConv2dBackwardFilterGPUSmall(mshadow::Stream<mxnet::gpu> *stream, |
| const DepthwiseArgs args, |
| const int block_height, |
| const DType* out_grad, |
| const DType* input, |
| DType* filter_grad) { |
| const int tile_width = args.in_width + args.filter_width - 1; |
| const int tile_height = block_height * 2 + args.filter_height - 1; |
| const int tile_pixels = tile_height * tile_width; |
| const int filter_pixels = args.filter_height * args.filter_width; |
| const int shared_memory_size = |
| kBlockSlices * (tile_pixels + filter_pixels * kAccumPixels) * sizeof(DType); |
| if (shared_memory_size > 46 * 1024) { |
| return false; |
| } |
| |
| dim3 block_dim = dim3(args.in_width, block_height, kBlockSlices); |
| const int num_out_grad = |
| args.batch * args.out_height * args.out_width * args.out_channel; |
| int block_count = num_out_grad/(block_dim.x * block_dim.y * block_dim.z) + 1; |
| auto s = mshadow::Stream<mxnet::gpu>::GetStream(stream); |
| if (args.filter_height == 3 && args.filter_width == 3) { |
| cuda::DepthwiseConv2dBackwardFilterKernelSmall<DType, kBlockSlices, kAccumPixels, 3, 3> |
| <<<block_count, block_dim, shared_memory_size, s>>>( |
| args, out_grad, input, filter_grad); |
| } else { |
| cuda::DepthwiseConv2dBackwardFilterKernelSmall<DType, kBlockSlices, kAccumPixels, -1, -1> |
| <<<block_count, block_dim, shared_memory_size, s>>>( |
| args, out_grad, input, filter_grad); |
| } |
| MSHADOW_CUDA_POST_KERNEL_CHECK(DepthwiseConv2dBackwardFilterKernelSmall); |
| return true; |
| } |
| |
| template <typename DType, int kBlockSlices> |
| bool TryLaunchDepthwiseConv2dBackwardFilterGPUSmall(mshadow::Stream<mxnet::gpu> *stream, |
| const DepthwiseArgs args, |
| const int block_height, |
| const DType* out_grad, |
| const DType* input, |
| DType* filter_grad) { |
| // Minimize (power of two) kAccumPixels, while satisfying |
| // kAccumPixels * 32 >= block_height * in_width * kBlockSlices. |
| const int block_pixels = block_height * args.in_width * kBlockSlices; |
| if (block_pixels > 512) { |
| return TryLaunchDepthwiseConv2dBackwardFilterGPUSmall<DType, kBlockSlices, 32>( |
| stream, args, block_height, out_grad, input, filter_grad); |
| } else if (block_pixels > 256) { |
| return TryLaunchDepthwiseConv2dBackwardFilterGPUSmall<DType, kBlockSlices, 16>( |
| stream, args, block_height, out_grad, input, filter_grad); |
| } else { |
| return TryLaunchDepthwiseConv2dBackwardFilterGPUSmall<DType, kBlockSlices, 8>( |
| stream, args, block_height, out_grad, input, filter_grad); |
| } |
| } |
| |
| template <typename DType> |
| bool TryLaunchDepthwiseConv2dBackwardFilterGPUSmall(mshadow::Stream<mxnet::gpu> *stream, |
| const DepthwiseArgs args, |
| const DType* out_grad, |
| const DType* input, |
| DType* filter_grad) { |
| // Maximize (power of two) kBlockSlices while keeping a block within 1024 |
| // threads (2 pixels per thread). |
| int block_slices = 8; |
| int block_height = (args.in_height + 1) / 2; |
| int round_mask = 1; |
| for (; block_slices > 1; block_slices /= 2) { |
| // args.in_width * block_height * kBlockSlices must be multiple of 32. |
| for (; block_height * args.in_width * block_slices & 31; |
| round_mask = round_mask * 2 + 1) { |
| block_height = (block_height + round_mask) & ~round_mask; |
| } |
| int block_size = block_height * args.in_width * block_slices; |
| if (block_size <= 1024) { |
| break; |
| } |
| } |
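| // For example, with in_height = 9 and in_width = 10 the loop above settles |
| // on block_height = 6 and block_slices = 8 (6 * 10 * 8 = 480 threads, a |
| // multiple of 32 and <= 1024). |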
| |
| if (!CanLaunchDepthwiseConv2dBackwardFilterGPUSmall(args, block_height)) { |
| return false; |
| } |
| |
| switch (block_slices) { |
| case 8: |
| return TryLaunchDepthwiseConv2dBackwardFilterGPUSmall<DType, 8>( |
| stream, args, block_height, out_grad, input, filter_grad); |
| case 4: |
| return TryLaunchDepthwiseConv2dBackwardFilterGPUSmall<DType, 4>( |
| stream, args, block_height, out_grad, input, filter_grad); |
| case 2: |
| return TryLaunchDepthwiseConv2dBackwardFilterGPUSmall<DType, 2>( |
| stream, args, block_height, out_grad, input, filter_grad); |
| default: |
| return false; |
| } |
| } |
| |
| } // namespace depthwise_conv |
| } // namespace tf |
| |
| #endif // MXNET_OPERATOR_NN_DEPTHWISE_CONVOLUTION_TF_CUH_ |