blob: 1d27427653b836ab5ef0053720e6583e80ca5116 [file] [log] [blame]
/*!
* Copyright (c) 2017 by Contributors
* \file batch_norm.cu
* \brief CUDA Batch Normalization code
* \author Chris Olivier, Bing Xu
* Adapted from Torch
*/
#include <cuda_runtime_api.h>
#include <algorithm>
#include "batch_norm-inl.h"
#define WRITE_DATA_FLAG 1
#define WRITE_GAMMA_FLAG 2
#define WRITE_BETA_FLAG 4
#define FIX_GAMMA_FLAG 8
#define IS_TRAINING_FLAG 16
#define USE_GLOBAL_STATS_FLAG 32
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
#include "./cudnn_batch_norm-inl.h"
#endif
#include "../common/cuda_utils.h"
using namespace mxnet;
/*! \brief inverse standard deviation <-> variance */
// VARIANCE_TO_INVSTD(var, eps) evaluates 1 / sqrt(var + eps).
// The expansion names `DType`, so a type called DType must be in scope wherever it is used;
// eps is converted to DType before being added to the variance.
#define VARIANCE_TO_INVSTD(__var$, __eps$) (1.0/sqrt((__var$) + DType(__eps$)))
// INVSTD_TO_VARIANCE(invstd, eps) evaluates 1 / invstd^2 - eps, i.e. it undoes the macro above.
#define INVSTD_TO_VARIANCE(__invstd$, __eps$) ((1.0 / ((__invstd$) * (__invstd$))) - (__eps$))
namespace mxnet {
namespace op {
namespace batchnorm {
namespace cuda {
// Number of threads treated as one warp by the reduction helpers in this namespace
static const unsigned WARP_SIZE = 32;
// The maximum number of threads in a block
// (also sizes the shared scratch array used by the pre-sm_30 warpSum fallback)
static const unsigned MAX_BLOCK_SIZE = 512U;
/*! \brief Explicit scalar conversion from type In to type Out, usable on host and device */
template<typename In, typename Out>
struct ScalarConvert {
  /*! \return `v` converted to Out */
  static __host__ __device__ __forceinline__ Out to(const In v) {
    return static_cast<Out>(v);
  }
};
// Chooses a block size for `nElem` elements: the first entry of
// {32, 64, 128, 256, MAX_BLOCK_SIZE} that is >= nElem. When `smaller` is true only the first
// four entries are considered and the fallback is MAX_BLOCK_SIZE / 2; otherwise the fallback
// is MAX_BLOCK_SIZE. nElem is compared as unsigned, so a negative value hits the fallback.
static unsigned getNumThreads(int nElem, const bool smaller) {
  const unsigned candidates[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
  const int usable = smaller ? 4 : 5;
  const unsigned requested = static_cast<unsigned>(nElem);
  int idx = 0;
  while (idx != usable && requested > candidates[idx]) {
    ++idx;
  }
  if (idx != usable) {
    return candidates[idx];
  }
  return smaller ? (MAX_BLOCK_SIZE >> 1) : MAX_BLOCK_SIZE;
}
// Index (0-based, counted from the least significant bit) of the highest set bit of `val`,
// derived from the count of leading zero bits in the 32-bit value.
__device__ __forceinline__ int getMSB(int val) {
  const int leadingZeros = __clz(val);
  return 31 - leadingZeros;
}
/*!
 * \brief Pair of AccReal accumulators, so that two sums can be carried through one reduction.
 * Inputs of type DType (or int) are converted to AccReal on construction.
 */
template<typename DType, typename AccReal>
struct Float2 {
  AccReal v1, v2;

  // Leaves both members uninitialized
  __device__ Float2() {}

  // Distinct values for the two accumulators
  __device__ Float2(DType first, DType second)
    : v1(ScalarConvert<DType, AccReal>::to(first))
    , v2(ScalarConvert<DType, AccReal>::to(second)) {}

  // Same DType value for both accumulators
  __device__ Float2(DType both)
    : v1(ScalarConvert<DType, AccReal>::to(both))
    , v2(ScalarConvert<DType, AccReal>::to(both)) {}

  // Same int value for both accumulators (this is what `(T) 0` resolves to)
  __device__ Float2(int both)
    : v1(ScalarConvert<int, AccReal>::to(both))
    , v2(ScalarConvert<int, AccReal>::to(both)) {}

  // Component-wise accumulation
  __device__ Float2 &operator+=(const Float2 &rhs) {
    v1 += rhs.v1;
    v2 += rhs.v2;
    return *this;
  }
};
/*! \brief Pointwise functor for reduce(): yields the tensor element itself, widened to AccReal */
template<typename DType, typename AccReal, typename DeviceTensor>
struct SumOp {
  __device__ SumOp(const DeviceTensor t) : tensor(t) {}

  // Element at (batch, plane, n) converted from DType to AccReal
  __device__ __forceinline__ AccReal operator()(int batch, int plane, int n) {
    const DType element = tensor.get_ref(batch, plane, n);
    return ScalarConvert<DType, AccReal>::to(element);
  }

  const DeviceTensor tensor;
};
/*! \brief Pointwise functor for reduce(): yields the squared deviation of an element from `mean` */
template<typename DType, typename AccReal, typename DeviceTensor>
struct VarOp {
  __device__ VarOp(AccReal m, const DeviceTensor t)
    : mean(m), tensor(t) {}

  // (x - mean)^2 for the element at (batch, plane, n)
  __device__ __forceinline__ AccReal operator()(int batch, int plane, int n) {
    const DType val = tensor.get_ref(batch, plane, n);
    // `auto` keeps whatever type `DType - AccReal` produces, so the product is formed in
    // exactly that type before the conversion to AccReal on return.
    const auto deviation = val - mean;
    return deviation * deviation;
  }

  const AccReal mean;
  const DeviceTensor tensor;
};
/*!
 * \brief Pointwise functor for reduce() in the backward pass.
 * For each element it yields the pair (gradOutput, gradOutput * (input - mean)), so a single
 * reduction produces both the gradient sum and the dot product with the centered input.
 */
template<typename DType, typename AccReal, typename DeviceTensor>
struct GradOp {
  __device__ GradOp(AccReal m, const DeviceTensor i, const DeviceTensor g)
    : mean(m), input(i), gradOutput(g) {}

  __device__ __forceinline__ Float2<DType, AccReal> operator()(int batch, int plane, int n) {
    const DType grad = gradOutput.get_ref(batch, plane, n);
    // The centered input is narrowed back to DType before the multiplication
    const DType centered =
      ScalarConvert<AccReal, DType>::to(input.get_ref(batch, plane, n) - mean);
    return Float2<DType, AccReal>(grad, grad * centered);
  }

  const AccReal mean;
  const DeviceTensor input;
  const DeviceTensor gradOutput;
};
// Sum across all threads within a warp
//
// Every lane of the warp receives the warp-wide total.
//
// Preconditions:
//  - All WARP_SIZE lanes of the calling warp must execute this call together: the shuffle path
//    names every lane in its participation mask (0xffffffff). The callers in this file satisfy
//    this because block sizes are multiples of WARP_SIZE and the call is made either by every
//    thread of the block or by the whole first warp.
//  - The pre-sm_30 fallback uses a static shared array indexed by threadIdx.x, so blockDim.x
//    must not exceed MAX_BLOCK_SIZE.
template<typename T>
static __device__ __forceinline__ T warpSum(T val) {
#if __CUDA_ARCH__ >= 300
  // Butterfly reduction: after log2(WARP_SIZE) XOR exchanges each lane holds the full sum
  for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
#if defined(CUDART_VERSION) && CUDART_VERSION >= 9000
    // The mask-less __shfl_xor is deprecated since CUDA 9 and not available when compiling for
    // Volta and newer; the *_sync variant also synchronizes the named lanes before exchanging.
    val += __shfl_xor_sync(0xffffffffu, val, 1 << i, WARP_SIZE);
#else
    val += __shfl_xor(val, 1 << i, WARP_SIZE);
#endif
  }
#else
  // No shuffle instruction available: exchange through shared memory instead
  __shared__ T values[MAX_BLOCK_SIZE];
  values[threadIdx.x] = val;
  __threadfence_block();
  // First thread index of this thread's warp
  const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
  // Add the values of the other WARP_SIZE - 1 lanes, starting from the next lane and wrapping
  for (int i = 1; i < WARP_SIZE; i++) {
    val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
  }
#endif
  return val;
}
// Warp-wide sum of a Float2: each component is reduced independently with the scalar warpSum
template<typename DType, typename AccReal>
static __device__ __forceinline__ Float2<DType, AccReal> warpSum(Float2<DType, AccReal> value) {
  Float2<DType, AccReal> total;
  total.v1 = warpSum(value.v1);
  total.v2 = warpSum(value.v2);
  return total;
}
// Sum across (batch, x/y/z) applying Op() pointwise
//
// Block-wide reduction for one plane: every thread of the block returns the same total of
// op(batch, plane, x) over all batch indices and all inner positions x.
//
// Must be called by every thread of the block (it contains __syncthreads()).
// The per-warp partials are staged in a 32-entry shared array indexed by warp number, and the
// final step sums those entries with a single warp.
// NOTE(review): this appears to rely on blockDim.x being a multiple of WARP_SIZE with at most
// 32 warps per block -- confirm against the launch configuration.
template<typename T, typename Op, typename DeviceTensor>
static __device__ T reduce(Op op, DeviceTensor tensor, int plane) {
// Per-thread partial sum: threads stride over the inner positions by blockDim.x
T sum = (T) 0;
for (int batch = 0; batch < tensor.OuterSize(); ++batch) {
for (int x = threadIdx.x; x < tensor.InnerSize(); x += blockDim.x) {
sum += op(batch, plane, x);
}
}
// sum over NumThreads within a warp
sum = warpSum(sum);
// 'transpose', and reduce within warp again
__shared__ T shared[32];
// Barrier before overwriting `shared`: a previous reduce() call in the same kernel may still
// have threads reading shared[0]
__syncthreads();
// First lane of each warp publishes that warp's total
if (threadIdx.x % WARP_SIZE == 0) {
shared[threadIdx.x / WARP_SIZE] = sum;
}
if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
// zero out the other entries in shared
shared[threadIdx.x] = (T) 0;
}
__syncthreads();
// The first warp sums the 32 staged entries; thread 0 stores the block total in shared[0]
if (threadIdx.x / WARP_SIZE == 0) {
sum = warpSum(shared[threadIdx.x]);
if (threadIdx.x == 0) {
shared[0] = sum;
}
}
__syncthreads();
// Everyone picks it up, should be broadcast into the whole gradInput
return shared[0];
}
/*!
 * \brief Forward pass that normalizes with the running statistics instead of batch statistics.
 *
 * Each block handles the plane selected by blockIdx.x; the block's threads stride over the
 * inner positions by blockDim.x and loop over every batch index. No shared memory is used.
 *
 * Side effects performed by thread 0 of each block:
 *  - saveMean[plane] receives runningMean[plane]
 *  - saveInvStd[plane] receives 1 / sqrt(runningVar[plane] + epsilon)
 *  - weight[plane] is overwritten with 1 when both WRITE_GAMMA_FLAG and FIX_GAMMA_FLAG are set
 *    and the weight tensor is non-empty
 *
 * The scale is taken as 1 when FIX_GAMMA_FLAG is set or the weight tensor is empty; the shift is
 * taken as 0 when the bias tensor is empty.
 */
template <typename DType, typename AccReal, typename DeviceTensor1, typename DeviceTensor>
__global__ void BatchNormalizationUpdateOutputInferenceKernel(
  DeviceTensor input,
  DeviceTensor output,
  DeviceTensor1 runningMean,
  DeviceTensor1 runningVar,
  DeviceTensor1 saveMean,
  DeviceTensor1 saveInvStd,
  DeviceTensor1 weight,
  DeviceTensor1 bias,
  const DType epsilon,
  const uint32_t flags) {
  const int plane = blockIdx.x;
  const bool fixGamma = (flags & FIX_GAMMA_FLAG) != 0;
  const bool hasWeight = weight.numElements() > 0;
  const bool hasBias = bias.numElements() > 0;

  // Per-plane constants shared by every element of the plane
  const AccReal invstd = VARIANCE_TO_INVSTD(runningVar[plane], epsilon);
  const AccReal mean = ScalarConvert<DType, AccReal>::to(runningMean[plane]);
  const AccReal gamma = (!fixGamma && hasWeight)
                        ? ScalarConvert<DType, AccReal>::to(weight[plane])
                        : ScalarConvert<int, AccReal>::to(1);
  const AccReal beta = hasBias
                       ? ScalarConvert<DType, AccReal>::to(bias[plane])
                       : ScalarConvert<int, AccReal>::to(0);

  // One thread per plane records the statistics that were used
  if (threadIdx.x == 0) {
    saveMean[plane] = runningMean[plane];
    saveInvStd[plane] = invstd;
    const bool writeGamma = (flags & WRITE_GAMMA_FLAG) != 0;
    if (writeGamma && fixGamma && hasWeight) {
      weight[plane] = AccReal(1);
    }
  }

  // Normalize, scale and shift every element of this plane
  const int nbatch = input.OuterSize();
  const int nx = input.InnerSize();
  for (int batch = 0; batch < nbatch; ++batch) {
    for (int x = threadIdx.x; x < nx; x += blockDim.x) {
      const DType inp = input.get_ref(batch, plane, x);
      output.get_ref(batch, plane, x) =
        ScalarConvert<AccReal, DType>::to(gamma * (inp - mean) * invstd + beta);
    }
  }
}
/*!
 * \brief Forward pass that normalizes with statistics computed from the current batch.
 *
 * Each block handles the plane selected by blockIdx.x. Mean and variance are obtained with two
 * block-wide reductions (reduce()), so every thread of the block must reach them; the block's
 * threads then stride over the inner positions by blockDim.x for every batch index.
 *
 * Side effects performed by thread 0 of each block:
 *  - saveMean[plane] receives the batch mean
 *  - saveInvStd[plane] receives 1 / sqrt(variance + epsilon), or 0 when both the summed squared
 *    deviation and epsilon are exactly 0
 *  - weight[plane] is overwritten with 1 when both WRITE_GAMMA_FLAG and FIX_GAMMA_FLAG are set
 *    and the weight tensor is non-empty
 *
 * The scale is taken as 1 when FIX_GAMMA_FLAG is set or the weight tensor is empty (same rule as
 * the inference kernel); the shift is taken as 0 when the bias tensor is empty.
 *
 * `momentum`, `runningMean` and `runningVar` are accepted but not used by this kernel.
 */
template<typename DType, typename AccReal, typename DeviceTensor1, typename DeviceTensor>
__global__ void BatchNormalizationUpdateOutputKernel(
  DeviceTensor input,
  DeviceTensor output,
  DeviceTensor1 weight,
  DeviceTensor1 bias,
  const AccReal epsilon,
  const AccReal momentum,
  DeviceTensor1 runningMean,
  DeviceTensor1 runningVar,
  DeviceTensor1 saveMean,
  DeviceTensor1 saveInvStd,
  const uint32_t flags) {
  const int plane = blockIdx.x;
  const int N = input.OuterSize() * input.InnerSize();
  const AccReal norm = AccReal(1) / N;

  // Compute the mean and variance across (batch, x/y/z)
  const AccReal mean = reduce<AccReal>(
    SumOp<DType, AccReal, DeviceTensor>(input), input, plane) * norm;
  __syncthreads();
  const AccReal varN = reduce<AccReal>(VarOp<DType, AccReal, DeviceTensor>(mean, input),
                                       input, plane);
  // Leave invStd at 0 instead of dividing by sqrt(0)
  AccReal invStd = 0;
  if (varN != AccReal(0) || epsilon != AccReal(0)) {
    invStd = AccReal(1.0) / sqrt(varN * norm + epsilon);
  }

  // Save the mean and inverse standard deviation of this batch
  if (threadIdx.x == 0) {
    // For one item (0th) per plane (channel), write the per-channel data (ie mean, variance, etc)
    saveMean[plane] = ScalarConvert<AccReal, DType>::to(mean);
    saveInvStd[plane] = invStd;
    if ((flags & WRITE_GAMMA_FLAG) != 0 && (flags & FIX_GAMMA_FLAG) != 0
        && weight.numElements() > 0) {
      weight[plane] = AccReal(1);
    }
  }

  // Write normalized and update the output.
  // With FIX_GAMMA_FLAG the scale is the constant 1 and weight[plane] is never read here.
  // Reading it would race with thread 0's store above (there is no barrier in between) and
  // would pick up the stored value whenever gamma is fixed but not written back.
  const AccReal gamma = ((flags & FIX_GAMMA_FLAG) == 0 && weight.numElements() > 0)
                        ? ScalarConvert<DType, AccReal>::to(weight[plane])
                        : ScalarConvert<int, AccReal>::to(1);
  const AccReal beta = bias.numElements() > 0 ? ScalarConvert<DType, AccReal>::to(bias[plane])
                                              : ScalarConvert<int, AccReal>::to(0);
  for (int batch = 0, nbatch = input.OuterSize(); batch < nbatch; ++batch) {
    for (int x = threadIdx.x, nx = input.InnerSize(); x < nx; x += blockDim.x) {
      const DType inp = input.get_ref(batch, plane, x);
      output.get_ref(batch, plane, x) =
        ScalarConvert<AccReal, DType>::to(gamma * (inp - mean) * invStd + beta);
    }
  }
}
// Bundle of the per-plane tensors used by BatchNormalizationBackwardKernel, so that they can be
// handed to the kernel as one argument.
template<typename DeviceTensor1>
struct CUDATensors {
DeviceTensor1 gradWeight;  // output: gradient w.r.t. the scale (skipped when empty)
DeviceTensor1 gradBias;  // output: gradient w.r.t. the shift (skipped when empty)
DeviceTensor1 weight;  // input: scale; treated as 1 when empty
DeviceTensor1 runningMean;  // read when not training on batch stats; updated when training
DeviceTensor1 runningVar;  // read when not training on batch stats; updated when training
DeviceTensor1 saveMean;  // input: mean recorded by the forward pass
DeviceTensor1 saveInvStd;  // input: inverse standard deviation recorded by the forward pass
};
/*!
 * \brief Backward pass: input gradient, scale/shift gradients and running-average update.
 *
 * Each block handles the plane selected by blockIdx.x. One block-wide reduction (reduce()) yields
 * both Sum(gradOutput) and Dot(input - mean, gradOutput), so every thread of the block must
 * reach it; the block's threads then stride over the inner positions by blockDim.x.
 *
 * Statistics used:
 *  - training without global stats: saveMean / saveInvStd
 *  - otherwise: runningMean and 1 / sqrt(runningVar + eps)
 *
 * The scale is taken as 1 when FIX_GAMMA_FLAG is set or the weight tensor is empty, matching the
 * scale that the forward pass applies.
 *
 * Writes:
 *  - runningMean / runningVar (thread 0, training without global stats only):
 *    new = old * momentum + batch_value * (1 - momentum)
 *  - gradInput, when it is non-empty and WRITE_DATA_FLAG is set
 *  - gradWeight[plane] (thread 0), when non-empty and WRITE_GAMMA_FLAG is set; the value is 0
 *    under FIX_GAMMA_FLAG
 *  - gradBias[plane] (thread 0), when non-empty and WRITE_BETA_FLAG is set
 */
template<typename DType, typename AccReal, typename DeviceTensor1, typename DeviceTensor>
static __global__ void BatchNormalizationBackwardKernel(
  const DeviceTensor input,
  const DeviceTensor gradOutput,
  DeviceTensor gradInput,
  CUDATensors<DeviceTensor1> tensors,
  const uint32_t flags,
  const AccReal momentum,
  const double eps) {
  int plane = blockIdx.x;
  int N = gradOutput.OuterSize() * gradOutput.InnerSize();

  const bool is_train_and_not_global_stats =
    (flags & IS_TRAINING_FLAG) != 0 && (flags & USE_GLOBAL_STATS_FLAG) == 0;

  AccReal mean, invstd;
  if (is_train_and_not_global_stats) {
    mean = ScalarConvert<DType, AccReal>::to(tensors.saveMean[plane]);
    invstd = tensors.saveInvStd[plane];
  } else {
    mean = ScalarConvert<DType, AccReal>::to(tensors.runningMean[plane]);
    invstd = VARIANCE_TO_INVSTD(tensors.runningVar[plane], eps);
  }

  // A fixed gamma means the forward scale is 1, so the gradient must be scaled by 1 as well,
  // whatever value happens to be stored in the weight tensor.
  const AccReal weightVal =
    ((flags & FIX_GAMMA_FLAG) == 0 && tensors.weight.numElements() > 0) ?
      ScalarConvert<DType, AccReal>::to(tensors.weight[plane]) : AccReal(1);
  const AccReal norm = AccReal(1) / N;

  // Compute two values across (batch, x/y/z) in one pass:
  // 1. Sum(gradOutput)
  // 2. DotProduct(input - mean, gradOutput)
  GradOp<DType, AccReal, DeviceTensor> g(mean, input, gradOutput);
  Float2< DType, AccReal > res = reduce < Float2 < DType, AccReal >,
    GradOp< DType, AccReal, DeviceTensor >, DeviceTensor > (g, gradOutput, plane);
  const AccReal gradOutputSum = res.v1;
  const AccReal dotP = res.v2;

  const AccReal gradMean = gradOutputSum * norm;
  const AccReal projScale = dotP * norm * invstd * invstd;
  const AccReal gradScale = invstd * weightVal;

  if (threadIdx.x == 0 && is_train_and_not_global_stats) {
    // Recover the batch variance from the saved inverse standard deviation
    const AccReal localVariance = INVSTD_TO_VARIANCE(tensors.saveInvStd[plane], eps);
    const AccReal localMean = tensors.saveMean[plane];

    // update running averages
    tensors.runningMean[plane] = tensors.runningMean[plane]
                                 * momentum + localMean * (AccReal(1) - momentum);
    tensors.runningVar[plane] = tensors.runningVar[plane]
                                * momentum + localVariance * (AccReal(1) - momentum);
  }

  if (gradInput.Size() > 0 && (flags & WRITE_DATA_FLAG) != 0) {
    for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
      for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
        const DType gradOut = gradOutput.get_ref(batch, plane, x);
        if (is_train_and_not_global_stats) {
          // Batch statistics depend on the input, so subtract the mean and the projection onto
          // the centered input before scaling
          const DType inp = input.get_ref(batch, plane, x);
          const AccReal proj = (inp - mean) * projScale;
          gradInput.get_ref(batch, plane, x) =
            ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
        } else {
          // Fixed statistics: the normalization is a per-plane affine map
          gradInput.get_ref(batch, plane, x) = ScalarConvert<AccReal, DType>::to(
            gradOut * gradScale);
        }
      }
    }
  }

  if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_GAMMA_FLAG) != 0) {
    if ((flags & FIX_GAMMA_FLAG) == 0) {
      tensors.gradWeight[plane] = ScalarConvert<AccReal, DType>::to(dotP * invstd);
    } else {
      tensors.gradWeight[plane] = DType(0);
    }
  }

  if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_BETA_FLAG) != 0) {
    tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
  }
}
/*!
 * \brief Non-owning view of a Dim-dimensional array: a data pointer plus one int extent per
 *        dimension. No bounds checking is performed by any accessor.
 */
template<typename DType, int Dim>
struct DeviceTensor {
 public:
  // Leaves the pointer and the extents uninitialized
  inline DeviceTensor() {}

  // Copies Dim extents from `size`; a null `size` sets every extent to 0
  inline DeviceTensor(DType *p, const int *size)
    : dptr_(p) {
    for (int d = 0; d != Dim; ++d) {
      size_[d] = size ? size[d] : 0;
    }
  }

  // Extent of dimension i
  MSHADOW_XINLINE unsigned getSize(const int i) const {
    return size_[i];
  }

  // Product of all extents
  MSHADOW_XINLINE int numElements() const {
    int total = 1;
    for (int d = 0; d != Dim; ++d) {
      total *= size_[d];
    }
    return total;
  }

  // Element at (batch, plane, x) with the last index varying fastest; reads extents 1 and 2
  MSHADOW_XINLINE DType &operator()(const size_t batch,
                                    const size_t plane,
                                    const size_t x) const {
    int offset = static_cast<int>(batch);
    offset *= size_[1];
    offset += plane;
    offset *= size_[2];
    offset += x;
    return *(const_cast<DType *>(dptr_ + offset));
  }

  // Element at flat position x
  MSHADOW_XINLINE DType &operator[](const size_t x) const {
    return *(dptr_ + x);
  }

  // Product of the extents from dimension 2 onwards (1 when there are none)
  MSHADOW_XINLINE size_t InnerSize() const {
    size_t inner = 1;
    for (size_t d = 2; d < Dim; ++d) {
      inner *= size_[d];
    }
    return inner;
  }

  // Extent of dimension 1
  MSHADOW_XINLINE size_t ChannelCount() const {
    return size_[1];
  }

  DType *dptr_;
  int size_[Dim];
};
/*!
 * \brief Builds a Dim-dimensional DeviceTensor view over the data of `blob`.
 *
 * When the blob has exactly Dim dimensions its extents are copied as they are. Otherwise missing
 * trailing dimensions get extent 1, and surplus trailing dimensions are folded into the last
 * extent by multiplication.
 */
template<typename DType, int Dim>
static DeviceTensor<DType, Dim> devicetensor(const TBlob &blob) {
  DType *data = blob.dptr<DType>();
  const int inDim = blob.shape_.ndim();

  if (inDim != Dim) {
    // View in which the last dimensions are collapsed or expanded as needed
    int size[Dim];
    // Dimensions present in both: copy; dimensions only in the view: extent 1
    for (int d = 0; d < Dim; ++d) {
      if (d < inDim) {
        size[d] = blob.size(d);
      } else {
        size[d] = 1;
      }
    }
    // Dimensions only in the blob: fold into the last extent of the view
    for (int d = Dim; d < inDim; ++d) {
      size[Dim - 1] *= blob.size(d);
    }
    return DeviceTensor<DType, Dim>(data, &size[0]);
  }

  // Same rank: one-to-one copy of the extents
  DeviceTensor<DType, Dim> tensor(data, nullptr);
  for (int d = 0; d < Dim; ++d) {
    tensor.size_[d] = blob.size(d);
  }
  return tensor;
}
// Shorthand for the 1-D tensor type used for the per-plane parameters and statistics.
// The expansion names `AccReal`, so it resolves to whatever AccReal means at the point of use.
#define DeviceTensor1 DeviceTensor<AccReal, 1>
using namespace mxnet::op;
/*!
 * \brief Host-side launcher for the forward pass.
 *
 * Wraps the data/output blobs as 3-D views along `param.axis` and the per-plane blobs as 1-D
 * views, then launches one block per channel on the CUDA stream of `s`:
 *  - the inference kernel when not training or when USE_GLOBAL_STATS_FLAG is set
 *  - the batch-statistics kernel otherwise
 * The block size is chosen from the inner size of the input. The launch is followed by
 * MSHADOW_CUDA_POST_KERNEL_CHECK.
 */
template<typename DType, typename AccReal>
static void BatchNormalizationUpdateOutput(mshadow::Stream<gpu> *s,
                                           const OpContext &ctx,
                                           const BatchNormParam& param,
                                           const std::vector<TBlob> &in_data,
                                           const std::vector<TBlob> &out_data,
                                           const std::vector<TBlob> &aux_states,
                                           const uint32_t flags,
                                           double momentum,
                                           double eps) {
  // 3-D views of the data tensors
  batchnorm::BNTensor3<DType> input(in_data[batchnorm::kData], param.axis);
  batchnorm::BNTensor3<DType> output(out_data[batchnorm::kOut], param.axis);

  // 1-D views of the per-plane tensors
  DeviceTensor1 weight = devicetensor<AccReal, 1>(in_data[batchnorm::kGamma]);
  DeviceTensor1 bias = devicetensor<AccReal, 1>(in_data[batchnorm::kBeta]);
  DeviceTensor1 runningMean = devicetensor<AccReal, 1>(aux_states[batchnorm::kMovingMean]);
  DeviceTensor1 runningVar = devicetensor<AccReal, 1>(aux_states[batchnorm::kMovingVar]);
  DeviceTensor1 saveMean = devicetensor<AccReal, 1>(out_data[batchnorm::kMean]);
  DeviceTensor1 saveInvStd = devicetensor<AccReal, 1>(out_data[batchnorm::kVar]);

  DCHECK_GT(weight.numElements(), 0);

  // Both kernels share the same launch geometry: one block per channel
  const dim3 blocks(input.ChannelCount());
  const dim3 threads(batchnorm::cuda::getNumThreads(input.InnerSize(), false));

  const bool training = (flags & IS_TRAINING_FLAG) != 0;
  const bool globalStats = (flags & USE_GLOBAL_STATS_FLAG) != 0;

  if (training && !globalStats) {
    BatchNormalizationUpdateOutputKernel<DType, AccReal, DeviceTensor1,
      batchnorm::BNTensor3<DType>>
      <<<blocks, threads, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
      input, output, weight, bias, eps, momentum, runningMean, runningVar,
        saveMean, saveInvStd, flags);
  } else {
    BatchNormalizationUpdateOutputInferenceKernel<DType, AccReal, DeviceTensor1,
      batchnorm::BNTensor3<DType>>
      <<<blocks, threads, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
      input, output, runningMean, runningVar, saveMean,
        saveInvStd, weight, bias, eps, flags);
  }
  MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormalizationUpdateOutput);
}
/*!
 * \brief Host-side launcher for the backward pass.
 *
 * Wraps the data / output-gradient / input-gradient blobs as 3-D views along `param.axis`,
 * gathers the per-plane blobs into a CUDATensors bundle, and launches
 * BatchNormalizationBackwardKernel with one block per channel on the CUDA stream of `s`.
 * Builds without NDEBUG request the smaller block-size range from getNumThreads. The launch is
 * followed by MSHADOW_CUDA_POST_KERNEL_CHECK.
 */
template<typename DType, typename AccReal>
static void BatchNormalizationBackward(mshadow::Stream<gpu> *s,
                                       const OpContext &ctx,
                                       const BatchNormParam& param,
                                       const std::vector<TBlob> &out_grad,
                                       const std::vector<TBlob> &in_data,
                                       const std::vector<TBlob> &out_data,
                                       const std::vector<TBlob> &in_grad,
                                       const std::vector<TBlob> &aux_states,
                                       const uint32_t flags,
                                       double momentum,
                                       double eps) {
  // 3-D views of the data tensors
  batchnorm::BNTensor3<DType> input(in_data[batchnorm::kData], param.axis);
  batchnorm::BNTensor3<DType> gradOutput(out_grad[batchnorm::kOut], param.axis);
  batchnorm::BNTensor3<DType> gradInput(in_grad[batchnorm::kData], param.axis);

  // 1-D views of the per-plane tensors
  CUDATensors<DeviceTensor1> tensors;
  tensors.gradWeight = devicetensor<AccReal, 1>(in_grad[batchnorm::kGamma]);
  tensors.gradBias = devicetensor<AccReal, 1>(in_grad[batchnorm::kBeta]);
  tensors.weight = devicetensor<AccReal, 1>(in_data[batchnorm::kGamma]);
  tensors.runningMean = devicetensor<AccReal, 1>(aux_states[batchnorm::kMovingMean]);
  tensors.runningVar = devicetensor<AccReal, 1>(aux_states[batchnorm::kMovingVar]);
  tensors.saveMean = devicetensor<AccReal, 1>(out_data[batchnorm::kMean]);
  tensors.saveInvStd = devicetensor<AccReal, 1>(out_data[batchnorm::kVar]);

  DCHECK_GT(tensors.weight.numElements(), 0);

#ifndef NDEBUG
  constexpr bool SMALLER_THREADS = true;
#else
  constexpr bool SMALLER_THREADS = false;
#endif

  // One block per channel
  const dim3 blocks(gradOutput.ChannelCount());
  const dim3 threads(batchnorm::cuda::getNumThreads(gradOutput.InnerSize(), SMALLER_THREADS));

  BatchNormalizationBackwardKernel<DType, AccReal, DeviceTensor1, batchnorm::BNTensor3<DType>>
    <<<blocks, threads, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
    input, gradOutput, gradInput, tensors, flags, momentum, eps);
  MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormalizationBackward);
}
} // namespace cuda
} // namespace batchnorm
/*!
 * \brief Packs the operator state into the bit flags understood by the kernels.
 *
 * IS_TRAINING_FLAG, FIX_GAMMA_FLAG and USE_GLOBAL_STATS_FLAG mirror `ctx.is_train`,
 * `params.fix_gamma` and `params.use_global_stats`. The WRITE_* flags are set when
 * BatchNormOp::IsWriting reports true for the request at index kData, kGamma and kBeta of `req`.
 */
template<typename xpu, typename DType, typename AccReal>
static inline uint32_t SetupFlags(const OpContext &ctx,
                                  const BatchNormParam& params,
                                  const std::vector<OpReqType> &req) {
  typedef BatchNormOp<xpu, DType, AccReal> OpType;
  uint32_t flags = 0;
  if (ctx.is_train) {
    flags |= IS_TRAINING_FLAG;
  }
  if (params.fix_gamma) {
    flags |= FIX_GAMMA_FLAG;
  }
  if (params.use_global_stats) {
    flags |= USE_GLOBAL_STATS_FLAG;
  }
  if (OpType::IsWriting(req[batchnorm::kData])) {
    flags |= WRITE_DATA_FLAG;
  }
  if (OpType::IsWriting(req[batchnorm::kGamma])) {
    flags |= WRITE_GAMMA_FLAG;
  }
  if (OpType::IsWriting(req[batchnorm::kBeta])) {
    flags |= WRITE_BETA_FLAG;
  }
  return flags;
}
/*!
 * \brief Forward batch-norm pass on GPU
 * Derives the kernel flags from the context, parameters and requests, then delegates to
 * batchnorm::cuda::BatchNormalizationUpdateOutput with this operator's momentum and eps.
 */
template<typename xpu, typename DType, typename AccReal>
void BatchNormOp<xpu, DType, AccReal>::DoForward(mshadow::Stream<gpu> *stream,
                                                 const OpContext &ctx,
                                                 const std::vector<TBlob> &in_data,
                                                 const std::vector<OpReqType> &req,
                                                 const std::vector<TBlob> &out_data,
                                                 const std::vector<TBlob> &aux_states) {
  const uint32_t flags = SetupFlags<xpu, DType, AccReal>(ctx, param_, req);
  batchnorm::cuda::BatchNormalizationUpdateOutput<DType, AccReal>(
    stream, ctx, param_, in_data, out_data, aux_states,
    flags, param_.momentum, param_.eps);
  MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormOp_DoForward_gpu);
}
/*!
 * \brief Backward batch-norm pass on GPU
 * Derives the kernel flags from the context, parameters and requests, then delegates to
 * batchnorm::cuda::BatchNormalizationBackward with this operator's momentum and eps.
 */
template<typename xpu, typename DType, typename AccReal>
void BatchNormOp<xpu, DType, AccReal>::DoBackward(mshadow::Stream<gpu> *stream,
                                                  const OpContext &ctx,
                                                  const std::vector<TBlob> &out_grad,
                                                  const std::vector<TBlob> &in_data,
                                                  const std::vector<TBlob> &out_data,
                                                  const std::vector<OpReqType> &req,
                                                  const std::vector<TBlob> &in_grad,
                                                  const std::vector<TBlob> &aux_states) {
  const uint32_t flags = SetupFlags<xpu, DType, AccReal>(ctx, param_, req);
  batchnorm::cuda::BatchNormalizationBackward<DType, AccReal>(
    stream, ctx, param_, out_grad, in_data, out_data, in_grad, aux_states,
    flags, param_.momentum, param_.eps);
  MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormOp_DoBackward_gpu);
}
/*!
 * \brief Create GPU operator for batch normalization
 * The axis in `param` is first resolved with GetRealAxis. When built with cuDNN (major version
 * >= 5), the cuDNN operator is chosen if global stats are not requested, cuDNN is not switched
 * off, the shape has at most 4 dimensions and the axis is the default one; in every other case
 * the native BatchNormOp is created.
 * \return newly allocated operator (null if no type case of the dispatch macro assigned one)
 */
template<>
Operator *CreateOp<gpu>(BatchNormParam param, const int dtype, const TShape& shape) {
  param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
  Operator *op = nullptr;
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
  const bool cudnn_eligible = !param.use_global_stats && !param.cudnn_off && shape.ndim() <= 4
                              && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS;
  if (cudnn_eligible) {
    MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
      op = new CuDNNBatchNormOp<DType>(param);
    })
  } else {
    MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
      op = new BatchNormOp<gpu, DType, AccReal>(param);
    })
  }
#else
  MSHADOW_REAL_TYPE_SWITCH_EX(dtype,
                              DType,
                              AccReal,
                              { op = new BatchNormOp<gpu, DType, AccReal>(param); });
#endif
  return op;
}
} // namespace op
} // namespace mxnet