/*********************************************************
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
************************************************************/
#include "cudnn_batchnorm.h"
#ifdef USE_CUDNN
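
// cuDNN-backed implementation of batch normalization: Forward and Backward
// delegate to the cudnnBatchNormalization* routines.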
namespace singa {
RegisterLayerClass(cudnn_batchnorm, CudnnBatchNorm);
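
// The cuDNN descriptors are created lazily in InitCudnn(), so release them
// only if initialization actually happened.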
CudnnBatchNorm::~CudnnBatchNorm() {
  if (has_init_cudnn_) {
    CUDNN_CHECK(cudnnDestroyTensorDescriptor(shape_desc_));
    CUDNN_CHECK(cudnnDestroyTensorDescriptor(param_desc_));
  }
}
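
// Move the parameters and the saved mini-batch statistics to the target
// device.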
void CudnnBatchNorm::ToDevice(std::shared_ptr<Device> device) {
  BatchNorm::ToDevice(device);
  resultSaveMean_.ToDevice(device);
  resultSaveVariance_.ToDevice(device);
}
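
// Every parameter and statistic of batch normalization is a per-channel
// vector of length channels_.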
void CudnnBatchNorm::Setup(const Shape& in_sample, const LayerConf& conf) {
  BatchNorm::Setup(in_sample, conf);
  bnScale_.Reshape(Shape{channels_});
  bnBias_.Reshape(Shape{channels_});
  dbnScale_.Reshape(Shape{channels_});
  dbnBias_.Reshape(Shape{channels_});
  runningMean_.Reshape(Shape{channels_});
  runningVariance_.Reshape(Shape{channels_});
  resultSaveMean_.Reshape(Shape{channels_});
  resultSaveVariance_.Reshape(Shape{channels_});
}
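
// Create the cuDNN descriptors on first use. CUDNN_BATCHNORM_SPATIAL
// computes one mean/variance per channel, over the N, H and W dimensions.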
void CudnnBatchNorm::InitCudnn(const Shape& shape, DataType dtype) {
  CHECK(!has_init_cudnn_);
  mode_ = CUDNN_BATCHNORM_SPATIAL;
  CUDNN_CHECK(cudnnCreateTensorDescriptor(&shape_desc_));
  CUDNN_CHECK(cudnnCreateTensorDescriptor(&param_desc_));
  CHECK_EQ(shape.size(), 4u);
  // Input/output tensors are laid out as N x C x H x W.
  CUDNN_CHECK(cudnnSetTensor4dDescriptor(shape_desc_, CUDNN_TENSOR_NCHW,
                                         GetCudnnDataType(dtype), shape[0],
                                         shape[1], shape[2], shape[3]));
  // Scale, bias, mean and variance are per-channel, i.e. 1 x C x 1 x 1.
  CUDNN_CHECK(cudnnSetTensor4dDescriptor(param_desc_, CUDNN_TENSOR_NCHW,
                                         GetCudnnDataType(dtype), 1, shape[1],
                                         1, 1));
  has_init_cudnn_ = true;
}
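
// In training mode, normalize with mini-batch statistics and cache the input
// for Backward; in evaluation mode, normalize with the running statistics.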
const Tensor CudnnBatchNorm::Forward(int flag, const Tensor& input) {
  auto shape = input.shape();
  auto dtype = input.data_type();
  Tensor output;
  Tensor x;
  // A 2D input (N x C) is normalized as a 4D NCHW tensor with H = W = 1.
  if (is_2d_)
    x = Reshape(input, Shape{shape.at(0), shape.at(1), 1, 1});
  else
    x = input;
  shape = x.shape();
  if (!has_init_cudnn_) InitCudnn(shape, dtype);
  // TODO(wangji): check device id of input and params
  output.ResetLike(x);
  if ((flag & kTrain) == kTrain) {
    output.device()->Exec(
        [=](Context* ctx) {
          Block *inBlock = x.block(), *outBlock = output.block(),
                *saveMeanBlock = resultSaveMean_.block(),
                *saveVarBlock = resultSaveVariance_.block(),
                *runningMeanBlock = runningMean_.block(),
                *runningVarBlock = runningVariance_.block(),
                *bnScaleBlock = bnScale_.block(),
                *bnBiasBlock = bnBias_.block();
          const float alpha = 1.0f, beta = 0.0f;
          double epsilon = CUDNN_BN_MIN_EPSILON;
          // Normalize with the mini-batch mean/variance, update the running
          // statistics with exponential-average factor factor_, and save the
          // batch statistics for reuse in Backward.
          CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
              ctx->cudnn_handle, this->mode_, &alpha, &beta, shape_desc_,
              inBlock->data(), shape_desc_, outBlock->mutable_data(),
              param_desc_, bnScaleBlock->data(), bnBiasBlock->data(), factor_,
              runningMeanBlock->mutable_data(),
              runningVarBlock->mutable_data(), epsilon,
              saveMeanBlock->mutable_data(), saveVarBlock->mutable_data()));
        },
        {x.block(), bnScale_.block(), bnBias_.block()},
        {output.block(), runningMean_.block(), runningVariance_.block(),
         resultSaveMean_.block(), resultSaveVariance_.block()});
    // Cache the input for the backward pass.
    buf_.push(x);
  } else {
    output.device()->Exec(
        [=](Context* ctx) {
          Block *inBlock = x.block(), *outBlock = output.block(),
                *runningMeanBlock = runningMean_.block(),
                *runningVarBlock = runningVariance_.block(),
                *bnScaleBlock = bnScale_.block(),
                *bnBiasBlock = bnBias_.block();
          const float alpha = 1.0f, beta = 0.0f;
          double epsilon = CUDNN_BN_MIN_EPSILON;
          // At inference time, normalize with the running mean/variance
          // accumulated during training.
          CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
              ctx->cudnn_handle, this->mode_, &alpha, &beta, shape_desc_,
              inBlock->data(), shape_desc_, outBlock->mutable_data(),
              param_desc_, bnScaleBlock->data(), bnBiasBlock->data(),
              runningMeanBlock->data(), runningVarBlock->data(), epsilon));
        },
        {x.block(), bnScale_.block(), bnBias_.block(), runningMean_.block(),
         runningVariance_.block()},
        {output.block()});
  }
  if (is_2d_) output.Reshape(Shape{shape.at(0), shape.at(1)});
  return output;
}
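
// Gradients are computed from the input cached by Forward and the batch
// mean/variance saved by cudnnBatchNormalizationForwardTraining.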
const std::pair<Tensor, vector<Tensor>> CudnnBatchNorm::Backward(
    int flag, const Tensor& grad) {
  vector<Tensor> param_grad;
  Tensor dx;
  if ((flag & kTrain) == kTrain) {
    Tensor x = buf_.top();
    buf_.pop();
    dx.ResetLike(grad);
    dx.device()->Exec(
        [=](Context* ctx) {
          Block *dyblock = grad.block(), *dxblock = dx.block(),
                *xblock = x.block(), *bnScaleBlock = bnScale_.block(),
                *dbnScaleBlock = dbnScale_.block(),
                *dbnBiasBlock = dbnBias_.block(),
                *saveMeanBlock = resultSaveMean_.block(),
                *saveVarBlock = resultSaveVariance_.block();
          const float alpha = 1.0f, beta = 0.0f;
          double epsilon = CUDNN_BN_MIN_EPSILON;
          CUDNN_CHECK(cudnnBatchNormalizationBackward(
              ctx->cudnn_handle, this->mode_, &alpha, &beta, &alpha, &beta,
              shape_desc_, xblock->data(), shape_desc_, dyblock->data(),
              shape_desc_, dxblock->mutable_data(), param_desc_,
              bnScaleBlock->data(), dbnScaleBlock->mutable_data(),
              dbnBiasBlock->mutable_data(), epsilon, saveMeanBlock->data(),
              saveVarBlock->data()));
        },
        // The kernel reads x (not dx), so list x.block() as a read
        // dependency; dx is written.
        {x.block(), grad.block(), bnScale_.block(), resultSaveMean_.block(),
         resultSaveVariance_.block()},
        {dx.block(), dbnScale_.block(), dbnBias_.block()});
  } else {
    LOG(ERROR) << "Backward should not be called in the evaluation phase";
  }
  param_grad.push_back(dbnScale_);
  param_grad.push_back(dbnBias_);
  // The running mean/variance are updated by moving average rather than by
  // gradients, so push zero tensors as their gradients.
  Tensor dummy;
  dummy.ResetLike(dbnScale_);
  dummy.SetValue(0.0f);
  param_grad.push_back(dummy);
  param_grad.push_back(dummy);
  if (is_2d_) dx.Reshape(Shape{dx.shape().at(0), dx.shape().at(1)});
  return std::make_pair(dx, param_grad);
}
}  // namespace singa
#endif // USE_CUDNN