/*!
* Copyright (c) 2015 by Contributors
* \file cudnn_batch_norm-inl.h
* \brief Batch normalization operator implemented with cuDNN.
* \author Junyuan Xie
*/
#ifndef MXNET_OPERATOR_CUDNN_BATCH_NORM_INL_H_
#define MXNET_OPERATOR_CUDNN_BATCH_NORM_INL_H_
#include <vector>
#include <map>
#include <string>
#include <utility>
#include "batch_norm-inl.h"
namespace mxnet {
namespace op {
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 4
namespace cudnnbatchnorm {
enum CuDNNBatchNormOpInputs {kData, kGamma, kBeta};
enum CuDNNBatchNormOpOutputs {kOut, kMean, kInvVar};
enum CuDNNBatchNormOpAuxiliary {kMovingMean, kMovingInvVar};
} // namespace cudnnbatchnorm
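// Blob layout used throughout this operator:
//   inputs     : data, gamma (scale), beta (bias)
//   outputs    : normalized output plus the batch mean and inverse variance
//                saved by cuDNN during training for reuse in Backward
//   aux states : moving (running) mean and inverse variance, updated with the
//                configured momentum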
#if defined(__CUDACC__)
template<typename DType>
class CuDNNBatchNormOp : public Operator {
public:
explicit CuDNNBatchNormOp(BatchNormParam param) {
using namespace mshadow;
CHECK_GE(param.eps, CUDNN_BN_MIN_EPSILON)
<< "CuDNN requires eps to be no less than " << CUDNN_BN_MIN_EPSILON;
this->param_ = param;
init_cudnn_ = false;
dtype_ = DataType<DType>::kCudnnFlag;
// For a float16 input type, beta, gamma and the running mean/variance are
// stored in float32. For other input types they share the input's type.
dtype_param_ = (dtype_ == CUDNN_DATA_HALF) ? kFloat32 : DataType<DType>::kFlag;
}
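// The tensor descriptors are created lazily on the first Forward call
// (guarded by init_cudnn_) and released when the operator is destroyed.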
~CuDNNBatchNormOp() {
if (init_cudnn_) {
CUDNN_CALL(cudnnDestroyTensorDescriptor(io_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(mean_desc_));
}
}
virtual void Forward(const OpContext &ctx,
const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data,
const std::vector<TBlob> &aux_states) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(in_data.size(), 3U);
CHECK_EQ(aux_states.size(), 2U);
if (ctx.is_train) {
CHECK_EQ(out_data.size(), 3U);
CHECK_EQ(req.size(), 3U);
} else {
CHECK_GE(out_data.size(), 1U);
CHECK_GE(req.size(), 1U);
}
CHECK_EQ(req[cudnnbatchnorm::kOut], kWriteTo);
CHECK_GE(in_data[cudnnbatchnorm::kData].ndim(), 2);
CHECK_LE(in_data[cudnnbatchnorm::kData].ndim(), 4);
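// cuDNN operates on 4D NCHW tensors, so inputs with fewer dimensions are
// padded with trailing 1s before the descriptors are created (done once).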
if (!init_cudnn_) {
for (int i = 0; i < 4; ++i) {
if (i < in_data[cudnnbatchnorm::kData].ndim()) {
shape_[i] = in_data[cudnnbatchnorm::kData].shape_[i];
} else {
shape_[i] = 1;
}
}
CUDNN_CALL(cudnnCreateTensorDescriptor(&io_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&mean_desc_));
CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_,
CUDNN_TENSOR_NCHW,
dtype_,
shape_[0],
shape_[1],
shape_[2],
shape_[3]));
CUDNN_CALL(cudnnDeriveBNTensorDescriptor(mean_desc_,
io_desc_,
CUDNN_BATCHNORM_SPATIAL));
init_cudnn_ = true;
}
Stream<gpu> *s = ctx.get_stream<gpu>();
Tensor<gpu, 4, DType> x =
in_data[cudnnbatchnorm::kData].get_with_shape<gpu, 4, DType>(shape_, s);
Tensor<gpu, 4, DType> y =
out_data[cudnnbatchnorm::kOut].get_with_shape<gpu, 4, DType>(shape_, s);
MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, {
Tensor<gpu, 1, DTypeParam> gamma =
in_data[cudnnbatchnorm::kGamma].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
Tensor<gpu, 1, DTypeParam> beta =
in_data[cudnnbatchnorm::kBeta].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
Tensor<gpu, 1, DTypeParam> moving_mean =
aux_states[cudnnbatchnorm::kMovingMean]
.get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
Tensor<gpu, 1, DTypeParam> moving_inv_var =
aux_states[cudnnbatchnorm::kMovingInvVar]
.get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
typename DataType<DType>::ScaleType a = 1.0f;
typename DataType<DType>::ScaleType b = 0.0f;
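// alpha/beta of 1/0 make cuDNN overwrite y rather than blend with it; with
// fix_gamma the scale is pinned to 1, so only beta is effectively learned.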
if (param_.fix_gamma) gamma = 1.f;
if (ctx.is_train) {
Tensor<gpu, 1, DTypeParam> save_mean =
out_data[cudnnbatchnorm::kMean].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
Tensor<gpu, 1, DTypeParam> save_inv_var =
out_data[cudnnbatchnorm::kInvVar]
.get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
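// cuDNN's exponential-average factor weights the new batch statistics,
// whereas MXNet's momentum weights the old running values, hence 1 - momentum.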
CUDNN_CALL(cudnnBatchNormalizationForwardTraining(s->dnn_handle_,
CUDNN_BATCHNORM_SPATIAL,
&a,
&b,
io_desc_,
x.dptr_,
io_desc_,
y.dptr_,
mean_desc_,
gamma.dptr_,
beta.dptr_,
1 - param_.momentum,
moving_mean.dptr_,
moving_inv_var.dptr_,
param_.eps,
save_mean.dptr_,
save_inv_var.dptr_));
} else {
CUDNN_CALL(cudnnBatchNormalizationForwardInference(s->dnn_handle_,
CUDNN_BATCHNORM_SPATIAL,
&a,
&b,
io_desc_,
x.dptr_,
io_desc_,
y.dptr_,
mean_desc_,
gamma.dptr_,
beta.dptr_,
moving_mean.dptr_,
moving_inv_var.dptr_,
param_.eps));
}
})
}
virtual void Backward(const OpContext &ctx,
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &in_grad,
const std::vector<TBlob> &aux_states) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(out_grad.size(), 1U);
CHECK_EQ(in_data.size(), 3U);
CHECK_EQ(out_data.size(), 3U);
CHECK_EQ(in_grad.size(), 3U);
CHECK(ctx.is_train && !param_.use_global_stats)
<< "Use of global statistics is not yet supported in CuDNNBatchNorm.";
Stream<gpu> *s = ctx.get_stream<gpu>();
Tensor<gpu, 4, DType> x =
in_data[cudnnbatchnorm::kData].get_with_shape<gpu, 4, DType>(shape_, s);
Tensor<gpu, 4, DType> dx =
in_grad[cudnnbatchnorm::kData].get_with_shape<gpu, 4, DType>(shape_, s);
Tensor<gpu, 4, DType> dy =
out_grad[cudnnbatchnorm::kOut].get_with_shape<gpu, 4, DType>(shape_, s);
#if CUDNN_VERSION >= 4007
MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, {
Tensor<gpu, 1, DTypeParam> gamma =
in_data[cudnnbatchnorm::kGamma].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
Tensor<gpu, 1, DTypeParam> dbeta =
in_grad[cudnnbatchnorm::kBeta].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
Tensor<gpu, 1, DTypeParam> dgamma =
in_grad[cudnnbatchnorm::kGamma].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
Tensor<gpu, 1, DTypeParam> save_mean =
out_data[cudnnbatchnorm::kMean].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
Tensor<gpu, 1, DTypeParam> save_inv_var =
out_data[cudnnbatchnorm::kInvVar].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
typename DataType<DType>::ScaleType a = 1.0f;
typename DataType<DType>::ScaleType b = 0.0f;
typename DataType<DType>::ScaleType b_add = 1.0f;
CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
if (param_.fix_gamma) gamma = 1.f;
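// The fourth scaling factor applies to the parameter gradients: 0 overwrites
// dgamma/dbeta (kWriteTo), while 1 accumulates into the existing buffers.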
CUDNN_CALL(cudnnBatchNormalizationBackward(
s->dnn_handle_,
CUDNN_BATCHNORM_SPATIAL,
&a,
&b,
&a,
req[cudnnbatchnorm::kGamma] == kWriteTo ? &b : &b_add,
io_desc_,
x.dptr_,
io_desc_,
dy.dptr_,
io_desc_,
dx.dptr_,
mean_desc_,
gamma.dptr_,
dgamma.dptr_,
dbeta.dptr_,
param_.eps,
save_mean.dptr_,
save_inv_var.dptr_));
if (param_.fix_gamma) dgamma = 0.f;
})
#else // CUDNN_VERSION < 4007
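// Older cuDNN releases (< 4007) expose only a single alpha/beta pair, so
// gradient accumulation for gamma/beta cannot be honoured in this branch.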
MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, {
Tensor<gpu, 1, DTypeParam> gamma =
in_data[cudnnbatchnorm::kGamma].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
Tensor<gpu, 1, DTypeParam> dbeta =
in_grad[cudnnbatchnorm::kBeta].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
Tensor<gpu, 1, DTypeParam> dgamma =
in_grad[cudnnbatchnorm::kGamma].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
Tensor<gpu, 1, DTypeParam> save_mean =
out_data[cudnnbatchnorm::kMean].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
Tensor<gpu, 1, DTypeParam> save_inv_var =
out_data[cudnnbatchnorm::kInvVar].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
typename DataType<DType>::ScaleType a = 1.0f;
typename DataType<DType>::ScaleType b = 0.0f;
typename DataType<DType>::ScaleType b_add = 1.0f;
CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
if (param_.fix_gamma) gamma = 1.f;
CUDNN_CALL(cudnnBatchNormalizationBackward(s->dnn_handle_,
CUDNN_BATCHNORM_SPATIAL,
&a,
&b,
io_desc_,
x.dptr_,
io_desc_,
dy.dptr_,
io_desc_,
dx.dptr_,
mean_desc_,
gamma.dptr_,
dgamma.dptr_,
dbeta.dptr_,
param_.eps,
save_mean.dptr_,
save_inv_var.dptr_));
if (param_.fix_gamma) dgamma = 0.f;
})
#endif
}
private:
bool init_cudnn_;
cudnnDataType_t dtype_;
int dtype_param_;
cudnnTensorDescriptor_t io_desc_, mean_desc_;
mshadow::Shape<4> shape_;
BatchNormParam param_;
};
#endif // defined(__CUDACC__)
template<typename xpu>
Operator *CreateOp_CuDNNv4(BatchNormParam param);
#if DMLC_USE_CXX11
class CuDNNBatchNormProp : public OperatorProperty {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
}
std::map<std::string, std::string> GetParams() const override {
return param_.__DICT__();
}
bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape,
std::vector<TShape> *aux_shape) const override {
using namespace mshadow;
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
const TShape &dshape = in_shape->at(0);
if (dshape.ndim() == 0) return false;
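// gamma, beta, the saved statistics and both moving aux states are all
// per-channel vectors of shape (dshape[1],).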
in_shape->at(1) = TShape(Shape1(dshape[1]));
in_shape->at(2) = TShape(Shape1(dshape[1]));
out_shape->clear();
out_shape->push_back(dshape);
out_shape->push_back(Shape1(dshape[1]));
out_shape->push_back(Shape1(dshape[1]));
aux_shape->clear();
aux_shape->push_back(Shape1(dshape[1]));
aux_shape->push_back(Shape1(dshape[1]));
return true;
}
OperatorProperty* Copy() const override {
auto ptr = new CuDNNBatchNormProp();
ptr->param_ = param_;
return ptr;
}
std::string TypeString() const override {
return "CuDNNBatchNorm";
}
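// Backward needs the output gradient, the saved batch statistics, the input
// data and gamma; beta itself does not enter the gradient computation.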
std::vector<int> DeclareBackwardDependency(
const std::vector<int> &out_grad,
const std::vector<int> &in_data,
const std::vector<int> &out_data) const override {
return {out_grad[cudnnbatchnorm::kOut],
out_data[cudnnbatchnorm::kMean],
out_data[cudnnbatchnorm::kInvVar],
in_data[cudnnbatchnorm::kData],
in_data[cudnnbatchnorm::kGamma]
};
}
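// Only the normalized output is exposed to users; the saved mean and inverse
// variance are hidden outputs consumed by the backward pass.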
int NumVisibleOutputs() const override {
return 1;
}
int NumOutputs() const override {
return 3;
}
std::vector<std::string> ListArguments() const override {
return {"data", "gamma", "beta"};
}
std::vector<std::string> ListOutputs() const override {
return {"output", "mean", "inv_var"};
}
std::vector<std::string> ListAuxiliaryStates() const override {
return {"moving_mean", "moving_inv_var"};
}
Operator* CreateOperator(Context ctx) const override;
private:
BatchNormParam param_;
}; // class CuDNNBatchNormProp
#endif // DMLC_USE_CXX11
#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 4
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_CUDNN_BATCH_NORM_INL_H_