/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * \file image_random.cu
 * \brief GPU implementation of image transformation operators
 */
#include <cuda_runtime_api.h>
#include "./image_random-inl.h"
#include "../elemwise_op_common.h"
namespace mxnet {
namespace op {
namespace image {
using namespace mshadow;
// ToTensor Kernel for 3D input
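// Converts a single (H, W, C) image to (C, H, W) layout and casts to float,
// dividing every element by normalize_factor.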
/*
 * To avoid generating code that uses too many registers
 * (which would trigger a "too many resources requested for
 * launch" error), we tell the compiler that this kernel is
 * launched with at most cuda::kMaxThreadsPerBlock threads
 * per block. Setting __launch_bounds__ guarantees that such
 * a configuration can always be launched.
 */
template <typename xpu, typename Dtype>
__global__ void __launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
    ToTensorCudaKernel(const Tensor<xpu, 3, Dtype> input,
                       const Tensor<xpu, 3, float> output,
                       const int req,
                       const int N,
                       const int H,
                       const int W,
                       const int C,
                       const float normalize_factor) {
  // We process one image per thread block.
  // In the 3D case there is a single image and hence a single block,
  // so blockIdx.x is not used.
  for (int c = 0; c < C; ++c) {
    for (int h = threadIdx.y; h < H; h += blockDim.y) {
      for (int w = threadIdx.x; w < W; w += blockDim.x) {
        KERNEL_ASSIGN(output[c][h][w], req, input[h][w][c] / normalize_factor);
      }
    }
  }
}
// ToTensor Kernel for 4D input
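// Batched variant: converts an (N, H, W, C) input to an (N, C, H, W) float
// output. Each image in the batch is handled by one thread block (blockIdx.x).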
template <typename xpu, typename Dtype>
__global__ void __launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
    ToTensorCudaKernel(const Tensor<xpu, 4, Dtype> input,
                       const Tensor<xpu, 4, float> output,
                       const int req,
                       const int N,
                       const int H,
                       const int W,
                       const int C,
                       const float normalize_factor) {
  // We process one image per thread block.
  const int n = blockIdx.x;
  for (int c = 0; c < C; ++c) {
    for (int h = threadIdx.y; h < H; h += blockDim.y) {
      for (int w = threadIdx.x; w < W; w += blockDim.x) {
        KERNEL_ASSIGN(output[n][c][h][w], req, input[n][h][w][c] / normalize_factor);
      }
    }
  }
}
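// Launches the ToTensor kernel for either a 3D (single image) or 4D (batched)
// input tensor. One 32x32 thread block processes one image; for 3D input a
// single block is launched.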
template <typename DType, typename T1, typename T2>
void ToTensorImplCUDA(mshadow::Stream<gpu>* s,
                      const T1 input,
                      const T2 output,
                      const int req,
                      const float normalize_factor) {
  int blocks, H, W, C, N;
  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
  if (std::is_same<T1, Tensor<gpu, 3, DType>>::value) {
    // 3D Input - (H, W, C)
    N = 0;
    H = input.size(0);
    W = input.size(1);
    C = input.size(2);
    blocks = 1;
  } else {
    // 4D Input - (N, H, W, C)
    N = input.size(0);
    H = input.size(1);
    W = input.size(2);
    C = input.size(3);
    blocks = N > 0 ? N : 1;
  }
  ToTensorCudaKernel<gpu, DType>
      <<<blocks, dim3(32, 32), 0, stream>>>(input, output, req, N, H, W, C, normalize_factor);
  MSHADOW_CUDA_POST_KERNEL_CHECK(ToTensorCudaKernel);
}
// Normalize Forward CUDA Kernel
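// For each channel c of an (N, C, H, W) batch, computes
// output = (input - mean_c) / std_c. Up to three channels are supported via
// the per-channel mean/std scalar arguments; one image is processed per block.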
template <typename xpu, typename DType>
__global__ void __launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
    NormalizeCudaKernel(const DType* input,
                        DType* output,
                        const int req,
                        const int N,
                        const int C,
                        const int H,
                        const int W,
                        const float mean_d0,
                        const float mean_d1,
                        const float mean_d2,
                        const float std_d0,
                        const float std_d1,
                        const float std_d2) {
  // We process one image per thread block.
  const int n = blockIdx.x;
  const int length = H * W;
  const int step = C * length * n;
  float mean = mean_d0;
  float std = std_d0;
  for (int c = 0; c < C; ++c) {
    switch (c) {
      case 0:
        break;
      case 1:
        mean = mean_d1;
        std = std_d1;
        break;
      case 2:
        mean = mean_d2;
        std = std_d2;
        break;
    }
    for (int i = threadIdx.x; i < length; i += blockDim.x) {
      KERNEL_ASSIGN(*(output + step + i + (c * length)),
                    req,
                    (*(input + step + i + (c * length)) - mean) / std);
    }
  }
}
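// Host-side launcher for NormalizeCudaKernel: one block per image in the
// batch, cuda::kMaxThreadsPerBlock threads per block.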
template <typename DType>
void NormalizeImplCUDA(mshadow::Stream<gpu>* s,
                       const DType* input,
                       DType* output,
                       const int req,
                       const int N,
                       const int C,
                       const int H,
                       const int W,
                       const float mean_d0,
                       const float mean_d1,
                       const float mean_d2,
                       const float std_d0,
                       const float std_d1,
                       const float std_d2) {
  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
  NormalizeCudaKernel<gpu, DType>
      // 1 image per block. N is batch size.
      <<<N, dim3(cuda::kMaxThreadsPerBlock, 1), 0, stream>>>(
          input, output, req, N, C, H, W, mean_d0, mean_d1, mean_d2, std_d0, std_d1, std_d2);
  MSHADOW_CUDA_POST_KERNEL_CHECK(NormalizeCudaKernel);
}
// Normalize Backward Kernel
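// Since normalize computes (x - mean_c) / std_c, its gradient with respect to
// the input is 1 / std_c, so the backward pass scales the incoming gradient by
// 1 / std_c for each channel.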
template <typename xpu, typename DType>
__global__ void __launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
    NormalizeBackwardCudaKernel(const DType* out_grad,
                                DType* in_grad,
                                const int req,
                                const int N,
                                const int C,
                                const int H,
                                const int W,
                                const float std_d0,
                                const float std_d1,
                                const float std_d2) {
  // We process one image per thread block.
  const int n = blockIdx.x;
  const int length = H * W;
  const int step = C * length * n;
  float std = std_d0;
  for (int c = 0; c < C; ++c) {
    switch (c) {
      case 0:
        break;
      case 1:
        std = std_d1;
        break;
      case 2:
        std = std_d2;
        break;
    }
    for (int i = threadIdx.x; i < length; i += blockDim.x) {
      KERNEL_ASSIGN(*(in_grad + step + i + (c * length)),
                    req,
                    *(out_grad + step + i + (c * length)) * (1.0 / std));
    }
  }
}
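// Host-side launcher for NormalizeBackwardCudaKernel, mirroring the forward
// launch configuration: one block per image, kMaxThreadsPerBlock threads each.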
template <typename DType>
void NormalizeBackwardImplCUDA(mshadow::Stream<gpu>* s,
                               const DType* out_grad,
                               DType* in_grad,
                               const int req,
                               const int N,
                               const int C,
                               const int H,
                               const int W,
                               const float std_d0,
                               const float std_d1,
                               const float std_d2) {
  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
  NormalizeBackwardCudaKernel<gpu, DType>
      // 1 image per block. N is batch size.
      <<<N, dim3(cuda::kMaxThreadsPerBlock, 1), 0, stream>>>(
          out_grad, in_grad, req, N, C, H, W, std_d0, std_d1, std_d2);
  MSHADOW_CUDA_POST_KERNEL_CHECK(NormalizeBackwardCudaKernel);
}
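// Register the GPU FCompute implementations for the image operators declared
// in image_random-inl.h.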
NNVM_REGISTER_OP(_image_to_tensor).set_attr<FCompute>("FCompute<gpu>", ToTensorOpForward<gpu>);
NNVM_REGISTER_OP(_image_normalize).set_attr<FCompute>("FCompute<gpu>", NormalizeOpForward<gpu>);
NNVM_REGISTER_OP(_backward_image_normalize)
    .set_attr<FCompute>("FCompute<gpu>", NormalizeOpBackward<gpu>);
} // namespace image
} // namespace op
} // namespace mxnet