/*********************************************************
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
************************************************************/
#include "batchnorm.h"
#include <cctype>
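// CPU (DNNL/oneDNN) and GPU (cuDNN) batch normalization operations,
// selected at build time via USE_DNNL and USE_CUDNN.
//
// A minimal training-step sketch (illustrative only; it assumes x, scale,
// bias, running_mean and running_var are pre-allocated CPU tensors, with
// the parameter and running-statistics tensors of shape {channels}):
//
//   BatchNormHandle bnh(momentum, x);
//   auto outs = CpuBatchNormForwardTraining(bnh, x, scale, bias,
//                                           running_mean, running_var);
//   // outs = {y, batch_mean, batch_var}; pass the statistics to
//   // CpuBatchNormBackwardx during the backward pass.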
namespace singa {
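// Caches the shape information and momentum factor shared by the CPU and
// GPU paths. The input must be 2D (N, C) or 4D (N, C, H, W); 2D inputs
// are treated as 4D with H = W = 1. With DNNL enabled, the
// forward-training primitive descriptor is built once here and reused.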
BatchNormHandle::BatchNormHandle(const float momentum, const Tensor& input) {
factor = momentum;
batchsize = input.shape(0);
channels = input.shape(1);
if (input.nDim() == 4u) {
height = input.shape().at(2);
width = input.shape().at(3);
is_2d = false;
} else if (input.nDim() == 2u) {
height = 1;
width = 1;
is_2d = true;
} else {
LOG(FATAL) << "The dimension of input should either be 4D or 2D.";
}
#ifdef USE_DNNL
if (input.device()->lang() == kCpp) {
use_dnnl = true;
epsilon = 1e-5f;
x_dims = dnnl::memory::dims(input.shape().begin(), input.shape().end());
// the DNNL path supports f32 only
auto dtype_ = dnnl::memory::data_type::f32;
dnnl::memory::format_tag format_tag_ = get_dnnl_format_tag(input);
x_md = dnnl::memory::desc({x_dims}, dtype_, format_tag_);
// cache the forward-training descriptor; it is reused by the training
// forward pass and as a hint when creating the backward primitive
bn_fwd_training_d = new dnnl::batch_normalization_forward::desc(
dnnl::prop_kind::forward_training, x_md, epsilon,
dnnl::normalization_flags::use_scale_shift);
auto eng = input.device()->context(0)->dnnl_engine;
bn_fwd_training_pd = new dnnl::batch_normalization_forward::primitive_desc(
*bn_fwd_training_d, eng);
}
#endif // USE_DNNL
}
BatchNormHandle::~BatchNormHandle() {
#ifdef USE_DNNL
if (use_dnnl) {
delete bn_fwd_training_d;
delete bn_fwd_training_pd;
}
#endif // USE_DNNL
}
#ifdef USE_DNNL
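// Inference-time forward pass: normalizes x with the running mean and
// variance (use_global_stats) and applies the learned scale and bias.
// running_mean and running_var are not modified, but listing them in the
// write blocks conservatively serializes this op with the training ops
// that do update them.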
Tensor CpuBatchNormForwardInference(const BatchNormHandle& bnh, const Tensor& x,
const Tensor& bnScale, const Tensor& bnBias,
Tensor& running_mean, Tensor& running_var) {
CHECK_EQ(x.device()->lang(), kCpp);
Tensor y;
y.ResetLike(x);
Tensor w = get_bn_weight_from(bnScale, bnBias);
y.device()->Exec(
[y, w, x, &running_mean, &running_var, &bnh](Context* ctx) mutable {
auto eng = ctx->dnnl_engine;
using namespace dnnl;
auto x_mem = memory(bnh.x_md, eng, x.block()->mutable_data());
auto y_mem = memory(bnh.x_md, eng, y.block()->mutable_data());
// use the learned scale/shift parameters and the running (global) stats
auto flags_ = normalization_flags::use_scale_shift |
normalization_flags::use_global_stats;
auto bn_fwd_d = batch_normalization_forward::desc(
prop_kind::forward_inference, bnh.x_md, bnh.epsilon, flags_);
auto bn_fwd_pd =
batch_normalization_forward::primitive_desc(bn_fwd_d, eng);
auto m_mem = memory(bn_fwd_pd.mean_desc(), eng,
running_mean.block()->mutable_data());
auto v_mem = memory(bn_fwd_pd.variance_desc(), eng,
running_var.block()->mutable_data());
auto w_mem =
memory(bn_fwd_pd.weights_desc(), eng, w.block()->mutable_data());
// execution
batch_normalization_forward(bn_fwd_pd).execute(
ctx->dnnl_stream, {{DNNL_ARG_SRC, x_mem},
{DNNL_ARG_DST, y_mem},
{DNNL_ARG_SCALE_SHIFT, w_mem},
{DNNL_ARG_MEAN, m_mem},
{DNNL_ARG_VARIANCE, v_mem}});
ctx->dnnl_stream.wait();
},
{x.block(), w.block(), running_mean.block(), running_var.block()},
{y.block(), running_mean.block(), running_var.block()}, "CpuBatchNormForwardInference");
return y;
}
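// Training-time forward pass: normalizes x with the statistics of the
// current mini-batch, returns {y, batch_mean, batch_var}, and updates
// running_mean and running_var in place using the momentum factor.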
const std::vector<Tensor> CpuBatchNormForwardTraining(
const BatchNormHandle& bnh, const Tensor& x, const Tensor& bnScale,
const Tensor& bnBias, Tensor& running_mean, Tensor& running_var) {
CHECK_EQ(x.device()->lang(), kCpp);
Tensor y;
y.ResetLike(x);
// mini-batch mean and variance, saved for the backward pass
Tensor mean;
mean.ResetLike(running_mean);
Tensor var;
var.ResetLike(running_var);
// pack scale and bias into the single weight tensor layout that DNNL
// expects, which is also reused by the backward pass
Tensor w = get_bn_weight_from(bnScale, bnBias);
y.device()->Exec(
[y, mean, var, w, x, &running_mean, &running_var,
&bnh](Context* ctx) mutable {
auto eng = ctx->dnnl_engine;
using namespace dnnl;
auto x_mem = memory(bnh.x_md, eng, x.block()->mutable_data());
auto y_mem = memory(bnh.x_md, eng, y.block()->mutable_data());
auto m_mem = memory(bnh.bn_fwd_training_pd->mean_desc(), eng,
mean.block()->mutable_data());
auto v_mem = memory(bnh.bn_fwd_training_pd->variance_desc(), eng,
var.block()->mutable_data());
auto w_mem = memory(bnh.bn_fwd_training_pd->weights_desc(), eng,
w.block()->mutable_data());
batch_normalization_forward(*bnh.bn_fwd_training_pd)
.execute(ctx->dnnl_stream, {{DNNL_ARG_SRC, x_mem},
{DNNL_ARG_DST, y_mem},
{DNNL_ARG_SCALE_SHIFT, w_mem},
{DNNL_ARG_MEAN, m_mem},
{DNNL_ARG_VARIANCE, v_mem}});
ctx->dnnl_stream.wait();
// update the running statistics here, since mkl-dnn does not yet
// support maintaining them inside the primitive:
// https://github.com/intel/mkl-dnn/issues/371
// https://github.com/intel/mkl-dnn/issues/517
// https://arxiv.org/pdf/1502.03167.pdf
auto s = x.shape();
s[1] = 1;  // count every axis except channels
float p = Product(s);  // elements per channel: batchsize * height * width
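// DNNL's batch variance is the biased estimate (normalized by p);
// multiplying by p / (p - 1) yields the unbiased estimate folded into
// the running variance, as in Ioffe & Szegedy (2015).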
running_mean = running_mean * (1 - bnh.factor) + mean * bnh.factor;
running_var =
running_var * (1 - bnh.factor) + var * (p / (p - 1)) * bnh.factor;
},
{x.block(), w.block(), running_mean.block(), running_var.block()},
{y.block(), running_mean.block(), running_var.block(), mean.block(),
var.block()}, "CpuBatchNormForwardTraining");
return {y, mean, var};
}
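// Backward pass: given dy and the mini-batch statistics saved by the
// training forward pass, computes dx and the gradients of the scale and
// bias parameters in a single DNNL primitive.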
const std::vector<Tensor> CpuBatchNormBackwardx(
const BatchNormHandle& bnh, const Tensor& y, const Tensor& dy,
const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
const Tensor& mean, const Tensor& var) {
CHECK_EQ(x.device()->lang(), kCpp);
CHECK_EQ(y.device()->lang(), kCpp);
CHECK_EQ(dy.device()->lang(), kCpp);
CHECK_EQ(mean.device()->lang(), kCpp);
CHECK_EQ(var.device()->lang(), kCpp);
CHECK_EQ(bnScale.device()->lang(), kCpp);
CHECK_EQ(bnBias.device()->lang(), kCpp);
Tensor dx;
dx.ResetLike(dy);
// pack scale and bias into the single weight tensor layout that DNNL
// expects for the backward pass
Tensor w = get_bn_weight_from(bnScale, bnBias);
Tensor dw;
dw.ResetLike(w);
dx.device()->Exec(
[w, dw, dx, dy, x, y, mean, var, &bnh](Context* ctx) mutable {
auto eng = ctx->dnnl_engine;
using namespace dnnl;
auto x_mem = memory(bnh.x_md, eng, x.block()->mutable_data());
auto dx_mem = memory(bnh.x_md, eng, dx.block()->mutable_data());
auto y_mem = memory(bnh.x_md, eng, y.block()->mutable_data());
auto dy_mem = memory(bnh.x_md, eng, dy.block()->mutable_data());
auto m_mem = memory(bnh.bn_fwd_training_pd->mean_desc(), eng,
mean.block()->mutable_data());
auto v_mem = memory(bnh.bn_fwd_training_pd->variance_desc(), eng,
var.block()->mutable_data());
auto w_mem = memory(bnh.bn_fwd_training_pd->weights_desc(), eng,
w.block()->mutable_data());
auto bn_bwd_d = batch_normalization_backward::desc(
prop_kind::backward, bnh.x_md, bnh.x_md, bnh.epsilon,
normalization_flags::use_scale_shift);
auto bn_bwd_pd = batch_normalization_backward::primitive_desc(
bn_bwd_d, eng, *bnh.bn_fwd_training_pd);
auto dw_mem = memory(bn_bwd_pd.diff_weights_desc(), eng,
dw.block()->mutable_data());
batch_normalization_backward(bn_bwd_pd).execute(
ctx->dnnl_stream, {{DNNL_ARG_SRC, x_mem},
{DNNL_ARG_DIFF_SRC, dx_mem},
{DNNL_ARG_DIFF_DST, dy_mem},
{DNNL_ARG_MEAN, m_mem},
{DNNL_ARG_VARIANCE, v_mem},
{DNNL_ARG_DIFF_SCALE_SHIFT, dw_mem},
{DNNL_ARG_SCALE_SHIFT, w_mem}});
ctx->dnnl_stream.wait();
},
{x.block(), dy.block(), mean.block(), var.block(), w.block(), y.block()},
{dx.block(), dw.block()}, "CpuBatchNormBackwardx");
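// dw is in DNNL's packed scale_shift layout: the first bnScale.Size()
// values are the scale gradients and the following bnBias.Size() values
// the bias gradients; unpack them into separate tensors.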
singa::Tensor dbnScale(bnScale.shape());
CopyDataToFrom(&dbnScale, dw, bnScale.Size(), 0, 0);
singa::Tensor dbnBias(bnBias.shape());
CopyDataToFrom(&dbnBias, dw, bnBias.Size(), 0, bnScale.Size());
CHECK(dbnScale.nDim() == bnScale.nDim()) << "dbnScale ndim does not match bnScale";
CHECK(dbnBias.nDim() == bnBias.nDim()) << "dbnBias ndim does not match bnBias";
CHECK(dbnScale.shape()[0] == bnScale.shape()[0])
<< "dbnScale shape does not match bnScale";
CHECK(dbnBias.shape()[0] == bnBias.shape()[0])
<< "dbnBias shape does not match bnBias";
return {dx, dbnScale, dbnBias};
}
#endif // USE_DNNL
#ifdef USE_CUDNN
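// Extends BatchNormHandle with cuDNN descriptors. 2D inputs use
// CUDNN_BATCHNORM_PER_ACTIVATION; 4D inputs use CUDNN_BATCHNORM_SPATIAL,
// which can be switched to the persistent variant (potentially faster on
// some shapes) via an environment variable, e.g.
//   export CUDNN_BATCHNORM_ALG=CUDNN_BATCHNORM_SPATIAL_PERSISTENT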
CudnnBatchNormHandle::CudnnBatchNormHandle(const float momentum,
const Tensor& input)
: BatchNormHandle(momentum, input) {
if (is_2d) {
mode = CUDNN_BATCHNORM_PER_ACTIVATION;
} else {
mode = CUDNN_BATCHNORM_SPATIAL;
if (const char* env_p = std::getenv("CUDNN_BATCHNORM_ALG")) {
std::string alg(env_p);
// uppercase via unsigned char to avoid UB in std::toupper on plain char
std::transform(alg.begin(), alg.end(), alg.begin(),
[](unsigned char c) { return std::toupper(c); });
if (alg == "CUDNN_BATCHNORM_SPATIAL_PERSISTENT")
mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
LOG(INFO) << "CUDNN_BATCHNORM_ALG: " << alg;
}
}
DataType dtype = input.data_type();
CUDNN_CHECK(cudnnCreateTensorDescriptor(&shape_desc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&param_desc));
CUDNN_CHECK(cudnnSetTensor4dDescriptor(shape_desc, CUDNN_TENSOR_NCHW,
GetCudnnDataType(dtype), batchsize,
channels, height, width));
CUDNN_CHECK(cudnnSetTensor4dDescriptor(param_desc, CUDNN_TENSOR_NCHW,
GetCudnnDataType(dtype), 1, channels,
1, 1));
}
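// cuDNN training-time forward pass: returns the output together with the
// mean/variance buffers cuDNN saves for the backward pass, and updates
// running_mean / running_var in place; 2D inputs are viewed as
// N x C x 1 x 1 to fit the 4D descriptor.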
const std::vector<Tensor> GpuBatchNormForwardTraining(
const CudnnBatchNormHandle& cbnh, const Tensor& x, const Tensor& bnScale,
const Tensor& bnBias, Tensor& running_mean, Tensor& running_var) {
CHECK_EQ(x.device()->lang(), kCuda);
CHECK_EQ(bnScale.device()->lang(), kCuda);
CHECK_EQ(bnBias.device()->lang(), kCuda);
CHECK_EQ(running_mean.device()->lang(), kCuda);
CHECK_EQ(running_var.device()->lang(), kCuda);
Tensor mean;
Tensor var;
mean.ResetLike(running_mean);
var.ResetLike(running_var);
Shape shape = x.shape();
Tensor input(x); // for unification of 2d and 4d cases.
if (cbnh.is_2d) input.Reshape(Shape{shape.at(0), shape.at(1), 1, 1});
Tensor output;
output.ResetLike(x);
output.device()->Exec(
[=, &bnScale, &bnBias, &running_mean, &running_var,
&cbnh](Context* ctx) mutable {
const float alpha = 1.0f, beta = 0.0f;
double epsilon = CUDNN_BN_MIN_EPSILON;
CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
ctx->cudnn_handle, cbnh.mode, &alpha, &beta, cbnh.shape_desc,
input.block()->data(), cbnh.shape_desc,
output.block()->mutable_data(), cbnh.param_desc,
bnScale.block()->data(), bnBias.block()->data(), cbnh.factor,
running_mean.block()->mutable_data(),
running_var.block()->mutable_data(), epsilon,
mean.block()->mutable_data(), var.block()->mutable_data()));
},
{input.block(), bnScale.block(), bnBias.block(), running_mean.block(),
running_var.block()},
{output.block(), running_mean.block(), running_var.block(), mean.block(),
var.block()}, "GpuBatchNormForwardTraining");
if (cbnh.is_2d) output.Reshape(Shape{shape.at(0), shape.at(1)});
return {output, mean, var};
}
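// cuDNN inference-time forward pass: normalizes with the fixed running
// statistics; running_mean and running_var are read but never written.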
Tensor GpuBatchNormForwardInference(const CudnnBatchNormHandle& cbnh,
const Tensor& x, const Tensor& bnScale,
const Tensor& bnBias,
const Tensor& running_mean,
const Tensor& running_var) {
CHECK_EQ(x.device()->lang(), kCuda);
CHECK_EQ(bnScale.device()->lang(), kCuda);
CHECK_EQ(bnBias.device()->lang(), kCuda);
CHECK_EQ(running_mean.device()->lang(), kCuda);
CHECK_EQ(running_var.device()->lang(), kCuda);
Shape shape = x.shape();
Tensor input(x); // for unification of 2d and 4d cases.
if (cbnh.is_2d) input.Reshape(Shape{shape.at(0), shape.at(1), 1, 1});
Tensor output;
output.ResetLike(x);
output.device()->Exec(
[=, &bnScale, &bnBias, &running_mean, &running_var,
&cbnh](Context* ctx) mutable {
const float alpha = 1.0f, beta = 0.0f;
double epsilon = CUDNN_BN_MIN_EPSILON;
CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
ctx->cudnn_handle, cbnh.mode, &alpha, &beta, cbnh.shape_desc,
input.block()->data(), cbnh.shape_desc,
output.block()->mutable_data(), cbnh.param_desc,
bnScale.block()->data(), bnBias.block()->data(),
running_mean.block()->data(), running_var.block()->data(),
epsilon));
},
{input.block(), bnScale.block(), bnBias.block(), running_mean.block(),
running_var.block()},
{output.block()}, "GpuBatchNormForwardInference");
return output;
}
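// cuDNN backward pass: computes dx along with the scale and bias
// gradients from the statistics saved by the forward training pass; of
// the parameters, only bnScale is required by cuDNN's backward call.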
const std::vector<Tensor> GpuBatchNormBackward(
const CudnnBatchNormHandle& cbnh, const Tensor& dy, const Tensor& x,
const Tensor& bnScale, const Tensor& mean, const Tensor& var) {
CHECK_EQ(dy.device()->lang(), kCuda);
CHECK_EQ(x.device()->lang(), kCuda);
CHECK_EQ(bnScale.device()->lang(), kCuda);
CHECK_EQ(mean.device()->lang(), kCuda);
CHECK_EQ(var.device()->lang(), kCuda);
Tensor dx;
dx.ResetLike(dy);
Tensor dbnScale;
dbnScale.ResetLike(bnScale);
Tensor dbnBias;
dbnBias.ResetLike(bnScale);
dx.device()->Exec(
[=, &bnScale, &cbnh](Context* ctx) mutable {
const float alpha = 1.0f, beta = 0.0f;
double epsilon = CUDNN_BN_MIN_EPSILON;
CUDNN_CHECK(cudnnBatchNormalizationBackward(
ctx->cudnn_handle, cbnh.mode, &alpha, &beta, &alpha, &beta,
cbnh.shape_desc, x.block()->data(), cbnh.shape_desc,
dy.block()->data(), cbnh.shape_desc, dx.block()->mutable_data(),
cbnh.param_desc, bnScale.block()->data(),
dbnScale.block()->mutable_data(), dbnBias.block()->mutable_data(),
epsilon, mean.block()->data(), var.block()->data()));
},
{x.block(), dy.block(), bnScale.block(), mean.block(), var.block()},
{dx.block(), dbnScale.block(), dbnBias.block()}, "GpuBatchNormBackward");
if (cbnh.is_2d) dx.Reshape(Shape{dx.shape().at(0), dx.shape().at(1)});
return {dx, dbnScale, dbnBias};
}
#endif // USE_CUDNN
} // namespace singa