| /*! |
| * Copyright (c) 2016 by Contributors |
| * \file cudnn_spatial_transformer-inl.h |
| * \brief |
| * \author Wei Wu |
| */ |
| #ifndef MXNET_OPERATOR_CUDNN_SPATIAL_TRANSFORMER_INL_H_ |
| #define MXNET_OPERATOR_CUDNN_SPATIAL_TRANSFORMER_INL_H_ |
| |
| #include <algorithm> |
| #include <vector> |
| #include "./spatial_transformer-inl.h" |
| namespace mxnet { |
| namespace op { |
| #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 |
| template<typename DType> |
| class CuDNNSpatialTransformerOp : public Operator { |
| public: |
| explicit CuDNNSpatialTransformerOp(SpatialTransformerParam param) { |
| this->param_ = param; |
| init_cudnn_ = false; |
| dtype_ = mshadow::DataType<DType>::kCudnnFlag; |
| if (param_.sampler_type == st::kBilinear) { |
| sampler_ = CUDNN_SAMPLER_BILINEAR; |
| } |
| } |
| |
| ~CuDNNSpatialTransformerOp() { |
| if (init_cudnn_) { |
| CUDNN_CALL(cudnnDestroySpatialTransformerDescriptor(st_desc_)); |
| CUDNN_CALL(cudnnDestroyTensorDescriptor(in_desc_)); |
| CUDNN_CALL(cudnnDestroyTensorDescriptor(out_desc_)); |
| } |
| } |
| |
| virtual void Forward(const OpContext &ctx, |
| const std::vector<TBlob> &in_data, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &out_data, |
| const std::vector<TBlob> &aux_args) { |
| using namespace mshadow; |
| CHECK_EQ(in_data.size(), 2U); |
| CHECK_EQ(out_data.size(), 3U); |
| Stream<gpu> *s = ctx.get_stream<gpu>(); |
| Tensor<gpu, 4, DType> data = in_data[st::kData].get<gpu, 4, DType>(s); |
| Tensor<gpu, 4, DType> out = out_data[st::kOut].get<gpu, 4, DType>(s); |
| Shape<3> loc_shape = Shape3(data.size(0), 2, 3); |
| Shape<4> grid_shape = Shape4(out.size(0), out.size(2), out.size(3), 2); |
| Tensor<gpu, 3, DType> loc = in_data[st::kLoc].get_with_shape<gpu, 3, DType>(loc_shape, s); |
| Tensor<gpu, 4, DType> grid = out_data[st::kGridSrc] |
| .get_with_shape<gpu, 4, DType>(grid_shape, s); |
| if (!init_cudnn_) { |
| Init(s, in_data, out_data); |
| } |
| CHECK_EQ(data.CheckContiguous(), true); |
| CHECK_EQ(out.CheckContiguous(), true); |
| typename DataType<DType>::ScaleType alpha = 1.0f; |
| typename DataType<DType>::ScaleType beta = 0.0f; |
| if (param_.transform_type == st::kAffine) { |
| CUDNN_CALL(cudnnSpatialTfGridGeneratorForward(s->dnn_handle_, |
| st_desc_, |
| loc.dptr_, |
| grid.dptr_)); |
| } |
| CUDNN_CALL(cudnnSpatialTfSamplerForward(s->dnn_handle_, |
| st_desc_, |
| &alpha, |
| in_desc_, |
| data.dptr_, |
| grid.dptr_, |
| &beta, |
| out_desc_, |
| out.dptr_)); |
| } |
| |
| virtual void Backward(const OpContext &ctx, |
| const std::vector<TBlob> &out_grad, |
| const std::vector<TBlob> &in_data, |
| const std::vector<TBlob> &out_data, |
| const std::vector<OpReqType> &req, |
| const std::vector<TBlob> &in_grad, |
| const std::vector<TBlob> &aux_args) { |
| using namespace mshadow; |
| CHECK_EQ(in_data.size(), 2U); |
| CHECK_EQ(out_data.size(), 3U); |
| CHECK_EQ(out_grad.size(), 1U); |
| Stream<gpu> *s = ctx.get_stream<gpu>(); |
| Tensor<gpu, 4, DType> data = in_data[st::kData].get<gpu, 4, DType>(s); |
| Tensor<gpu, 4, DType> grad = out_grad[st::kOut].get<gpu, 4, DType>(s); |
| Tensor<gpu, 4, DType> ddata = in_grad[st::kData].get<gpu, 4, DType>(s); |
| Shape<3> loc_shape = Shape3(data.size(0), 2, 3); |
| Shape<4> grid_shape = Shape4(grad.size(0), grad.size(2), grad.size(3), 2); |
| Tensor<gpu, 3, DType> dloc = in_grad[st::kLoc].get_with_shape<gpu, 3, DType>(loc_shape, s); |
| Tensor<gpu, 4, DType> grid = out_data[st::kGridSrc] |
| .get_with_shape<gpu, 4, DType>(grid_shape, s); |
| // do not use out_grad[st::kGridSrc], because dgrid is a intermediate tensor, and not include in |
| // DeclareBackwardDependency, another, we can we reuse grid for inplace operator |
| typename DataType<DType>::ScaleType alpha = 1.0f; |
| typename DataType<DType>::ScaleType beta = 0.0f; |
| typename DataType<DType>::ScaleType alpha_dgrid = 1.0f; |
| typename DataType<DType>::ScaleType beta_dgrid = 0.0f; |
| CUDNN_CALL(cudnnSpatialTfSamplerBackward(s->dnn_handle_, |
| st_desc_, |
| &alpha, |
| in_desc_, |
| data.dptr_, |
| &beta, |
| in_desc_/*reuse in_desc_*/, |
| ddata.dptr_/*output*/, |
| &alpha_dgrid, |
| out_desc_/*reuse out_desc_*/, |
| grad.dptr_, |
| grid.dptr_, |
| &beta_dgrid, |
| grid.dptr_)); |
| if (param_.transform_type == st::kAffine) { |
| CUDNN_CALL(cudnnSpatialTfGridGeneratorBackward(s->dnn_handle_, |
| st_desc_, |
| grid.dptr_, |
| dloc.dptr_/*out*/)); |
| } |
| } |
| |
| private: |
| inline void Init(mshadow::Stream<gpu> *s, |
| const std::vector<TBlob> &in_data, |
| const std::vector<TBlob> &out_data) { |
| using namespace mshadow; |
| #if CUDNN_MAJOR >= 5 |
| format_ = CUDNN_TENSOR_NCHW; |
| #endif |
| CHECK_EQ(in_data.size(), 2U); |
| CHECK_EQ(out_data.size(), 3U); |
| if (!init_cudnn_) { |
| init_cudnn_ = true; |
| Tensor<gpu, 4, DType> data = in_data[st::kData].get<gpu, 4, DType>(s); |
| Tensor<gpu, 4, DType> out = out_data[st::kOut].get<gpu, 4, DType>(s); |
| CUDNN_CALL(cudnnCreateSpatialTransformerDescriptor(&st_desc_)); |
| CUDNN_CALL(cudnnCreateTensorDescriptor(&in_desc_)); |
| CUDNN_CALL(cudnnCreateTensorDescriptor(&out_desc_)); |
| CUDNN_CALL(cudnnSetTensor4dDescriptor(in_desc_, |
| format_, |
| dtype_, |
| data.size(0), |
| data.size(1), |
| data.size(2), |
| data.size(3))); |
| CUDNN_CALL(cudnnSetTensor4dDescriptor(out_desc_, |
| format_, |
| dtype_, |
| out.size(0), |
| out.size(1), |
| out.size(2), |
| out.size(3))); |
| if (param_.sampler_type == st::kBilinear) { |
| int dim[] = {static_cast<int>(out.size(0)), static_cast<int>(out.size(1)), |
| static_cast<int>(out.size(2)), static_cast<int>(out.size(3))}; |
| CUDNN_CALL(cudnnSetSpatialTransformerNdDescriptor(st_desc_, |
| sampler_, |
| dtype_, |
| 4, |
| dim)); |
| } |
| } |
| } |
| |
| bool init_cudnn_; |
| cudnnDataType_t dtype_; |
| cudnnSpatialTransformerDescriptor_t st_desc_; |
| cudnnTensorDescriptor_t in_desc_; |
| cudnnTensorDescriptor_t out_desc_; |
| cudnnSamplerType_t sampler_; |
| #if CUDNN_MAJOR >= 5 |
| cudnnTensorFormat_t format_; |
| #endif |
| SpatialTransformerParam param_; |
| }; |
| #endif // __CUDACC__ && CUDNN |
| } // namespace op |
| } // namespace mxnet |
| |
| #endif // MXNET_OPERATOR_CUDNN_SPATIAL_TRANSFORMER_INL_H_ |