/*********************************************************
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
************************************************************/
#include "cudnn_lrn.h"
#ifdef USE_CUDNN
#include "cudnn_utils.h"
namespace singa {
RegisterLayerClass(cudnn_lrn, CudnnLRN);
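// Release the cuDNN descriptors created in InitCudnn().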
CudnnLRN::~CudnnLRN() {
  if (has_init_cudnn_) {
    CUDNN_CHECK(cudnnDestroyLRNDescriptor(lrn_desc_));
    CUDNN_CHECK(cudnnDestroyTensorDescriptor(shape_desc_));
  }
}
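// Create (on the first call) and configure the cuDNN LRN and tensor
// descriptors for a 4-D NCHW input of the given shape and data type.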
void CudnnLRN::InitCudnn(const Shape& shape, DataType dtype) {
  CHECK_EQ(shape.size(), 4u);
  if (!has_init_cudnn_) {
    mode_ = CUDNN_LRN_CROSS_CHANNEL_DIM1;
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&shape_desc_));
    CUDNN_CHECK(cudnnCreateLRNDescriptor(&lrn_desc_));
    CUDNN_CHECK(
        cudnnSetLRNDescriptor(lrn_desc_, local_size_, alpha_, beta_, k_));
  }
  CUDNN_CHECK(cudnnSetTensor4dDescriptor(shape_desc_, CUDNN_TENSOR_NCHW,
                                         GetCudnnDataType(dtype), shape[0],
                                         shape[1], shape[2], shape[3]));
  has_init_cudnn_ = true;
}
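// Run cross-channel LRN with cuDNN. The descriptors are (re)initialized when
// the batch size changes; during training the input and output tensors are
// cached so Backward() can reuse them.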
const Tensor CudnnLRN::Forward(int flag, const Tensor& input) {
  auto shape = input.shape();
  auto dtype = input.data_type();
  if (!has_init_cudnn_) {
    InitCudnn(shape, dtype);
  } else {
    int n, c, h, w, s;
    cudnnDataType_t type;
    CUDNN_CHECK(cudnnGetTensor4dDescriptor(shape_desc_, &type, &n, &c, &h, &w,
                                           &s, &s, &s, &s));
    if (shape[0] != static_cast<size_t>(n)) InitCudnn(shape, dtype);
    CHECK(input.shape(1) == static_cast<size_t>(c) &&
          input.shape(2) == static_cast<size_t>(h) &&
          input.shape(3) == static_cast<size_t>(w))
        << "input sample shape should not change; "
        << "previous shape " << c << ", " << h << ", " << w
        << ", current shape " << input.shape(1) << ", " << input.shape(2)
        << ", " << input.shape(3);
  }
  Tensor output;
  output.ResetLike(input);
  output.device()->Exec(
      [=](Context* ctx) {
        Block *inblock = input.block(), *outblock = output.block();
        const float alpha = 1.0f, beta = 0.0f;
        CUDNN_CHECK(cudnnLRNCrossChannelForward(
            ctx->cudnn_handle, this->lrn_desc_, this->mode_, &alpha,
            this->shape_desc_, inblock->data(), &beta, this->shape_desc_,
            outblock->mutable_data()));
      },
      {input.block()}, {output.block()});
  if (flag & kTrain) {
    buf_.push(input);
    buf_.push(output);
  }
  return output;
}
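// Compute the gradient w.r.t. the input from the cached input/output of
// Forward(); LRN has no parameters, so param_grad stays empty.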
const std::pair<Tensor, vector<Tensor>> CudnnLRN::Backward(int flag,
                                                            const Tensor& grad) {
  vector<Tensor> param_grad;
  Tensor dx;
  CHECK(!buf_.empty());
  Tensor output = buf_.top();
  buf_.pop();
  Tensor input = buf_.top();
  buf_.pop();
  if ((flag & kTrain) == kTrain) {
    dx.ResetLike(grad);
    dx.device()->Exec(
        [=](Context* ctx) {
          Block *dyblock = grad.block(), *dxblock = dx.block();
          Block *yblock = output.block(), *xblock = input.block();
          float alpha = 1.0f, beta = 0.0f;
          CUDNN_CHECK(cudnnLRNCrossChannelBackward(
              ctx->cudnn_handle, this->lrn_desc_, this->mode_, &alpha,
              this->shape_desc_, yblock->data(), this->shape_desc_,
              dyblock->data(), this->shape_desc_, xblock->data(), &beta,
              this->shape_desc_, dxblock->mutable_data()));
        },
        {output.block(), grad.block(), input.block()}, {dx.block()});
  } else {
    LOG(ERROR) << "Do not call backward for evaluation phase";
  }
  return std::make_pair(dx, param_grad);
}
}  // namespace singa
#endif // USE_CUDNN