/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file batch_norm.cu
* \brief CUDA Batch Normalization code
* \author Chris Olivier, Bing Xu, Da Zheng
* Adapted from Torch
*/
#include <cuda_runtime_api.h>
#include <algorithm>
#include "batch_norm-inl.h"
#include "../../common/cuda/utils.h"
#define WRITE_DATA_FLAG (1 << 0)
#define WRITE_GAMMA_FLAG (1 << 1)
#define WRITE_BETA_FLAG (1 << 2)
#define FIX_GAMMA_FLAG (1 << 3)
#define IS_TRAINING_FLAG (1 << 4)
#define USE_GLOBAL_STATS_FLAG (1 << 5)
#define ADDTO_DATA_FLAG (1 << 6)
#define ADDTO_GAMMA_FLAG (1 << 7)
#define ADDTO_BETA_FLAG (1 << 8)
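// The bits above are OR-ed into the uint32_t `flags` word passed to every
// kernel (see SetupFlags below): WRITE_* / ADDTO_* mirror the OpReqType
// (kWriteTo / kAddTo) of each output, FIX_GAMMA_FLAG pins gamma at 1, and
// IS_TRAINING_FLAG / USE_GLOBAL_STATS_FLAG select between batch statistics
// and the running averages.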
#if MXNET_USE_CUDNN == 1
#include "./cudnn/cudnn_batch_norm.h"
#endif
#include "../../../include/mxnet/tensor_blob.h"
using namespace mxnet;
namespace {
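// The two helpers below convert between the variance held in the running
// statistics and the inverse standard deviation cached by the kernels:
//   invstd = 1 / sqrt(var + eps)   and, inversely,   var = 1 / invstd^2 - eps.
// For example, with var = 4 and eps = 0 these give invstd = 0.5, and
// converting back recovers var = 4 (up to rounding when eps > 0).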
/*! \brief inverse standard deviation <-> variance */
template <typename DType, typename AccReal>
MSHADOW_XINLINE AccReal variance_to_invstd(DType var, AccReal eps) {
return rsqrtf(static_cast<AccReal>(var) + eps);
}
template <>
MSHADOW_XINLINE double variance_to_invstd(double var, double eps) {
return rsqrt(var + eps);
}
template <typename AccReal>
MSHADOW_XINLINE AccReal invstd_to_variance(AccReal invstd, AccReal eps) {
return static_cast<AccReal>(1.0) / (invstd * invstd) - eps;
}
template <>
MSHADOW_XINLINE double invstd_to_variance(double invstd, double eps) {
return 1.0 / (invstd * invstd) - eps;
}
} // namespace
namespace mxnet {
namespace op {
namespace batchnorm {
namespace cuda {
static const unsigned WARP_SIZE = 32;
// The maximum number of threads in a block
static const unsigned MAX_BLOCK_SIZE = 512U;
template <typename In, typename Out>
struct ScalarConvert {
static __host__ __device__ __forceinline__ Out to(const In v) {
    return static_cast<Out>(v);
}
};
// Returns the number of threads per block for an input of size nElem, capped at MAX_BLOCK_SIZE
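// For example, getNumThreads(20) returns 32 and getNumThreads(100) returns
// 128; anything above 256 falls through to MAX_BLOCK_SIZE.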
static unsigned getNumThreads(int nElem) {
unsigned threadSizes[4] = {32, 64, 128, 256};
for (int i = 0; i != 4; ++i) {
if (static_cast<unsigned>(nElem) <= threadSizes[i]) {
return threadSizes[i];
}
}
return MAX_BLOCK_SIZE;
}
// Returns the index of the most significant 1 bit in `val`.
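// For example, getMSB(1) == 0 and getMSB(32) == 5; warpSum below uses it as
// log2(WARP_SIZE) to bound the shuffle butterfly.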
__device__ __forceinline__ int getMSB(int val) {
return 31 - __clz(val);
}
template <typename DType, typename AccReal>
struct Float2 {
AccReal v1, v2;
__device__ Float2() {}
__device__ Float2(DType v1, DType v2)
: v1(ScalarConvert<DType, AccReal>::to(v1)), v2(ScalarConvert<DType, AccReal>::to(v2)) {}
__device__ Float2(DType v)
: v1(ScalarConvert<DType, AccReal>::to(v)), v2(ScalarConvert<DType, AccReal>::to(v)) {}
__device__ Float2(int v)
: v1(ScalarConvert<int, AccReal>::to(v)), v2(ScalarConvert<int, AccReal>::to(v)) {}
__device__ Float2& operator+=(const Float2& a) {
v1 += a.v1;
v2 += a.v2;
return *this;
}
};
template <typename DType, typename AccReal, typename DeviceTensor>
struct SumOp {
__device__ SumOp(const DeviceTensor t) : tensor(t) {}
__device__ __forceinline__ AccReal operator()(int batch, int plane, int n) {
return ScalarConvert<DType, AccReal>::to(tensor.get_ref(batch, plane, n));
}
const DeviceTensor tensor;
};
template <typename DType, typename AccReal, typename DeviceTensor>
struct VarOp {
__device__ VarOp(AccReal m, const DeviceTensor t) : mean(m), tensor(t) {}
__device__ __forceinline__ AccReal operator()(int batch, int plane, int n) {
DType val = tensor.get_ref(batch, plane, n);
return (val - mean) * (val - mean);
}
const AccReal mean;
const DeviceTensor tensor;
};
template <typename DType, typename AccReal, typename DeviceTensor>
struct GradOp {
__device__ GradOp(AccReal m, const DeviceTensor i, const DeviceTensor g)
: mean(m), input(i), gradOutput(g) {}
__device__ __forceinline__ Float2<DType, AccReal> operator()(int batch, int plane, int n) {
const DType g = gradOutput.get_ref(batch, plane, n);
const DType c = ScalarConvert<AccReal, DType>::to(input.get_ref(batch, plane, n) - mean);
return Float2<DType, AccReal>(g, g * c);
}
const AccReal mean;
const DeviceTensor input;
const DeviceTensor gradOutput;
};
#if CUDA_VERSION >= 9000
#define FULLMASK 0xFFFFFFFF
#define __shfl_xor(...) __shfl_xor_sync(FULLMASK, __VA_ARGS__)
#endif
// Sum across all threads within a warp
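// This is a butterfly reduction: each step XORs the lane id with a growing
// stride (1, 2, 4, 8, 16), so after getMSB(WARP_SIZE) steps every lane holds
// the sum of all 32 lane values. The pre-sm_30 fallback stages the values
// through shared memory instead and assumes blockDim.x <= MAX_BLOCK_SIZE.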
template <typename T>
static __device__ __forceinline__ T warpSum(T val) {
#if __CUDA_ARCH__ >= 300
for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
val += __shfl_xor(val, 1 << i, WARP_SIZE);
}
#else
__shared__ T values[MAX_BLOCK_SIZE];
values[threadIdx.x] = val;
__threadfence_block();
const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
for (int i = 1; i < WARP_SIZE; i++) {
val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
}
#endif
return val;
}
template <typename DType, typename AccReal>
static __device__ __forceinline__ Float2<DType, AccReal> warpSum(Float2<DType, AccReal> value) {
value.v1 = warpSum(value.v1);
value.v2 = warpSum(value.v2);
return value;
}
// Sum across (batch, x/y/z) applying Op() pointwise
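// Two-stage block reduction: each thread strides over the (batch, inner)
// extent of one plane accumulating op(), partials are summed within each
// warp, lane 0 of every warp parks its result in `shared`, and the first warp
// reduces those entries. `shared` has 32 slots, one per warp at the
// 1024-thread hardware maximum.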
template <typename T, typename Op, typename DeviceTensor>
static __device__ T reduce(Op op, DeviceTensor tensor, int plane) {
T sum = (T)0;
for (int batch = 0; batch < tensor.OuterSize(); ++batch) {
for (int x = threadIdx.x; x < tensor.InnerSize(); x += blockDim.x) {
sum += op(batch, plane, x);
}
}
// sum over NumThreads within a warp
sum = warpSum(sum);
// 'transpose', and reduce within warp again
__shared__ T shared[32];
__syncthreads();
if (threadIdx.x % WARP_SIZE == 0) {
shared[threadIdx.x / WARP_SIZE] = sum;
}
if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
// zero out the other entries in shared
shared[threadIdx.x] = (T)0;
}
__syncthreads();
if (threadIdx.x / WARP_SIZE == 0) {
sum = warpSum(shared[threadIdx.x]);
if (threadIdx.x == 0) {
shared[0] = sum;
}
}
__syncthreads();
  // Every thread reads the final sum, effectively broadcasting it to the whole block
return shared[0];
}
namespace {
constexpr int inference_forward_threads = 512;
constexpr int shmem_elements = 1536;
} // namespace
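// The inference kernel below processes nvec = sizeof(LType) / sizeof(DType)
// elements per memory transaction through the vectorized_loader union (e.g.
// LType = double over DType = float gives nvec = 2). When the channel count
// fits in shmem_elements, the per-channel statistics are staged in shared
// memory once per block instead of being re-read from global memory for
// every element.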
template <typename DType, typename AType, typename LType, bool small_num_channels>
__launch_bounds__(inference_forward_threads) __global__
void BatchNormalizationUpdateOutputInferenceKernel(const DType* input,
DType* output,
const index_t size,
const index_t outer_size,
const index_t num_channels,
const index_t inner_size,
const AType* runningMean,
const AType* runningVar,
AType* saveMean,
AType* saveInvStd,
AType* weight,
AType* bias,
const AType epsilon,
const uint32_t flags) {
constexpr int nvec = sizeof(LType) / sizeof(DType);
__shared__ AType saved_invstd[shmem_elements];
__shared__ AType saved_mean[shmem_elements];
__shared__ AType saved_weight[shmem_elements];
__shared__ AType saved_bias[shmem_elements];
union vectorized_loader {
LType aligned;
DType separate[nvec]; // NOLINT(*)
__device__ inline vectorized_loader() {}
__device__ inline ~vectorized_loader() {}
} scratch;
if (small_num_channels) {
for (int i = threadIdx.x; i < num_channels; i += blockDim.x) {
saved_invstd[i] = variance_to_invstd(runningVar[i], epsilon);
saved_mean[i] = runningMean[i];
saved_weight[i] = (weight != nullptr && (flags & FIX_GAMMA_FLAG) == 0) ? weight[i] : 1;
saved_bias[i] = (bias != nullptr) ? bias[i] : 0;
}
__syncthreads();
}
const index_t tid = threadIdx.x + blockIdx.x * blockDim.x;
const index_t stride = blockDim.x * gridDim.x;
const LType* input_aligned = reinterpret_cast<const LType*>(input);
LType* output_aligned = reinterpret_cast<LType*>(output);
for (index_t i = tid; i < size / nvec; i += stride) {
scratch.aligned = input_aligned[i];
const index_t my_channel_base = (nvec * i) % (inner_size * num_channels);
#pragma unroll
for (int j = 0; j < nvec; ++j) {
      index_t my_channel = (my_channel_base + j) / inner_size;
      // Vector lanes that spill past the last channel wrap into the next outer row
      if (my_channel >= num_channels)
        my_channel = my_channel % num_channels;
AType current_input = static_cast<AType>(scratch.separate[j]);
AType invstd = small_num_channels ? saved_invstd[my_channel] :
variance_to_invstd(runningVar[my_channel], epsilon);
AType mean = small_num_channels ? saved_mean[my_channel] : runningMean[my_channel];
AType gamma =
small_num_channels ?
saved_weight[my_channel] :
((weight != nullptr && (flags & FIX_GAMMA_FLAG) == 0) ? weight[my_channel] : 1);
AType beta =
small_num_channels ? saved_bias[my_channel] : ((bias != nullptr) ? bias[my_channel] : 0);
current_input = gamma * (current_input - mean) * invstd + beta;
scratch.separate[j] = current_input;
}
output_aligned[i] = scratch.aligned;
    // Threads covering the first num_channels indices also mirror the running
    // statistics into the save outputs (and reset gamma when it is fixed)
    if (i < num_channels) {
saveMean[i] = runningMean[i];
saveInvStd[i] = variance_to_invstd(runningVar[i], epsilon);
if ((flags & WRITE_GAMMA_FLAG) != 0 && (flags & FIX_GAMMA_FLAG) != 0 && weight != nullptr) {
weight[i] = 1;
}
}
}
}
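// Training-mode forward: one block per channel ("plane"). The batch mean and
// the biased variance (normalized by N, not N - 1) are computed with the
// reduce() above, then every element is normalized as
//   y = gamma * (x - mean) * invstd + beta,  invstd = 1 / sqrt(var + eps).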
template <typename DType, typename AccReal, typename DeviceTensor1, typename DeviceTensor>
__global__ void BatchNormalizationUpdateOutputKernel(DeviceTensor input,
DeviceTensor output,
DeviceTensor1 weight,
DeviceTensor1 bias,
const AccReal epsilon,
const AccReal momentum,
DeviceTensor1 runningMean,
DeviceTensor1 runningVar,
DeviceTensor1 saveMean,
DeviceTensor1 saveInvStd,
const uint32_t flags) {
const int plane = blockIdx.x;
const int N = input.OuterSize() * input.InnerSize();
const AccReal norm = AccReal(1) / N;
// Compute the mean and variance across (batch, x/y/z)
const AccReal mean =
reduce<AccReal>(SumOp<DType, AccReal, DeviceTensor>(input), input, plane) * norm;
__syncthreads();
const AccReal varN =
reduce<AccReal>(VarOp<DType, AccReal, DeviceTensor>(mean, input), input, plane);
AccReal invStd = 0;
if (varN != AccReal(0) || epsilon != AccReal(0)) {
invStd = AccReal(1.0) / sqrt(varN * norm + epsilon);
}
  // Save the batch mean and inverse standard deviation
  if (threadIdx.x == 0) {
    // One thread (the 0th) per plane (channel) writes the per-channel
    // statistics; the momentum-based writeback of the running averages
    // happens in the backward kernel
saveMean[plane] = ScalarConvert<AccReal, DType>::to(mean);
saveInvStd[plane] = invStd;
if ((flags & WRITE_GAMMA_FLAG) != 0 && (flags & FIX_GAMMA_FLAG) != 0 &&
weight.numElements() > 0) {
weight[plane] = AccReal(1);
}
}
  // Normalize the input and write the result to the output
const AccReal gamma = ((flags & FIX_GAMMA_FLAG) == 0 && weight.numElements() > 0) ?
ScalarConvert<DType, AccReal>::to(weight[plane]) :
ScalarConvert<int, AccReal>::to(1);
const AccReal beta = bias.numElements() > 0 ? ScalarConvert<DType, AccReal>::to(bias[plane]) :
ScalarConvert<int, AccReal>::to(0);
for (int batch = 0, nbatch = input.OuterSize(); batch < nbatch; ++batch) {
for (int x = threadIdx.x, nx = input.InnerSize(); x < nx; x += blockDim.x) {
const DType inp = input.get_ref(batch, plane, x);
output.get_ref(batch, plane, x) =
ScalarConvert<AccReal, DType>::to(gamma * (inp - mean) * invStd + beta);
}
}
}
template <typename DeviceTensor1>
struct CUDATensors {
DeviceTensor1 gradWeight;
DeviceTensor1 gradBias;
DeviceTensor1 weight;
DeviceTensor1 runningMean;
DeviceTensor1 runningVar;
DeviceTensor1 saveMean;
DeviceTensor1 saveInvStd;
};
namespace {
inline int ceil_div(int x, int y) {
return (x + y - 1) / y;
}
} // namespace
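// Frozen (fixed-statistics) backward for channels-last data runs in two
// phases: phase 1 tiles the (outer, channel) plane with a
// (blocks_x, blocks_y) grid, writes dX in the same pass, and stores
// per-channel partial sums of dY and of dY * (x - mean) into temp_space;
// phase 2 folds the blocks_y partials into the final gradWeight / gradBias.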
template <int NTHREADS, typename DType, typename AType, typename LType>
__global__ void FrozenBatchNormalizationBackwardKernelCLastPhase1(const DType* input,
const DType* gradOutput,
AType* temp_space,
DType* gradInput,
const AType* weight,
const AType* runningMean,
const AType* runningVar,
const index_t outer,
const index_t num_channels,
const AType eps,
const uint32_t flags) {
using mxnet::common::cuda::warp_size;
constexpr int num_warps = NTHREADS / warp_size;
constexpr int nvec = sizeof(LType) >= sizeof(DType) ? sizeof(LType) / sizeof(DType) : 1;
const size_t stride = num_channels / nvec;
union vectorized_loader {
LType aligned;
DType separate[nvec]; // NOLINT(*)
__device__ inline vectorized_loader() {}
__device__ inline ~vectorized_loader() {}
};
vectorized_loader vec_input, vec_gradOutput;
__shared__ AType scratch[NTHREADS * 2 * nvec];
AType* my_values_gamma = &(scratch[threadIdx.x * nvec]);
AType* my_values_beta = &(scratch[(NTHREADS + threadIdx.x) * nvec]);
AType sum_gamma[nvec]; // NOLINT(*)
AType sum_beta[nvec]; // NOLINT(*)
#pragma unroll
for (int i = 0; i < nvec; ++i) {
sum_gamma[i] = 0;
sum_beta[i] = 0;
}
const size_t offset = blockIdx.x * warp_size;
const int my_warp = threadIdx.x / warp_size;
const int thread_idx_in_warp = threadIdx.x % warp_size;
AType invstd[nvec]; // NOLINT(*)
AType mean[nvec]; // NOLINT(*)
AType gamma[nvec]; // NOLINT(*)
size_t channel_offset = (offset + thread_idx_in_warp) * nvec;
if (channel_offset < num_channels) {
#pragma unroll
for (int i = 0; i < nvec; ++i) {
invstd[i] = variance_to_invstd(runningVar[channel_offset + i], eps);
mean[i] = runningMean[channel_offset + i];
gamma[i] = weight != nullptr ? weight[channel_offset + i] : 1;
}
}
const LType* aligned_gradOutput = reinterpret_cast<const LType*>(gradOutput);
const LType* aligned_input = reinterpret_cast<const LType*>(input);
LType* gradInput_aligned = reinterpret_cast<LType*>(gradInput);
const int rows_per_block = (outer + gridDim.y - 1) / gridDim.y;
const size_t start_row = my_warp + rows_per_block * blockIdx.y;
const size_t end_row = min(outer, static_cast<index_t>(rows_per_block * (blockIdx.y + 1)));
if (offset + thread_idx_in_warp < stride) {
for (size_t i = start_row; i < end_row; i += num_warps) {
const index_t idx = i * stride + offset + thread_idx_in_warp;
vec_gradOutput.aligned = aligned_gradOutput[idx];
vec_input.aligned = aligned_input[idx];
#pragma unroll
for (int j = 0; j < nvec; ++j) {
sum_beta[j] += static_cast<AType>(vec_gradOutput.separate[j]);
sum_gamma[j] += static_cast<AType>(vec_gradOutput.separate[j]) *
(static_cast<AType>(vec_input.separate[j]) - mean[j]);
}
if (flags & (WRITE_DATA_FLAG | ADDTO_DATA_FLAG)) {
// Gradient to input
#pragma unroll
for (int j = 0; j < nvec; ++j) {
vec_gradOutput.separate[j] *= invstd[j] * gamma[j];
}
if (flags & ADDTO_DATA_FLAG) {
vec_input.aligned = gradInput_aligned[idx];
#pragma unroll
for (int j = 0; j < nvec; ++j) {
vec_gradOutput.separate[j] += vec_input.separate[j];
}
}
gradInput_aligned[idx] = vec_gradOutput.aligned;
}
}
}
__syncthreads();
#pragma unroll
for (int i = 0; i < nvec; ++i) {
my_values_gamma[i] = sum_gamma[i];
my_values_beta[i] = sum_beta[i];
}
__syncthreads();
for (int i = num_warps / 2; i > 0; i /= 2) {
if (my_warp < i) {
const int shared_offset = nvec * i * warp_size;
#pragma unroll
for (int j = 0; j < nvec; ++j) {
my_values_gamma[j] += my_values_gamma[j + shared_offset];
my_values_beta[j] += my_values_beta[j + shared_offset];
}
}
__syncthreads();
}
if (threadIdx.x < min(warp_size * nvec, static_cast<int>(num_channels - nvec * offset))) {
const size_t offset_out = nvec * offset + blockIdx.y * num_channels;
const size_t offset_beta = gridDim.y * num_channels;
temp_space[offset_out + threadIdx.x] = scratch[threadIdx.x];
temp_space[offset_beta + offset_out + threadIdx.x] = scratch[NTHREADS * nvec + threadIdx.x];
}
}
template <typename AType>
__global__ void FrozenBatchNormalizationBackwardKernelCLastPhase2(const AType* temp_space,
const AType* runningVar,
AType* out_gamma,
AType* out_beta,
int lead_dim,
int n_blocks,
AType epsilon,
uint32_t flags) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid < lead_dim) {
AType sum_gamma = 0;
AType sum_beta = 0;
for (int i = tid; i < lead_dim * n_blocks; i += lead_dim) {
sum_gamma += temp_space[i];
sum_beta += temp_space[i + lead_dim * n_blocks];
}
if (flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) {
if ((flags & FIX_GAMMA_FLAG) == 0) {
const AType invstd = variance_to_invstd(runningVar[tid], epsilon);
if (flags & WRITE_GAMMA_FLAG) {
out_gamma[tid] = sum_gamma * invstd;
} else {
out_gamma[tid] += sum_gamma * invstd;
}
} else {
if (flags & WRITE_GAMMA_FLAG) {
out_gamma[tid] = 0;
}
}
}
if (flags & WRITE_BETA_FLAG) {
out_beta[tid] = sum_beta;
} else if (flags & ADDTO_BETA_FLAG) {
out_beta[tid] += sum_beta;
}
}
}
template <int NTHREADS, typename DType, typename AType, typename LType>
__global__ void FrozenBatchNormalizationBackwardKernel(const DType* input,
const DType* gradOutput,
DType* gradInput,
AType* gradWeight,
AType* gradBias,
const AType* weight,
const AType* runningMean,
const AType* runningVar,
const index_t outer,
const index_t inner,
const index_t num_channels,
const index_t NHW_div_nvec,
const AType eps,
const uint32_t flags) {
const index_t my_channel = blockIdx.x;
const AType invstd = variance_to_invstd(runningVar[my_channel], eps);
const AType mean = runningMean[my_channel];
const AType gamma = weight != nullptr ? weight[my_channel] : 1;
constexpr int nvec = sizeof(LType) > sizeof(DType) ? sizeof(LType) / sizeof(DType) : 1;
union vectorized_loader {
LType aligned;
DType separate[nvec]; // NOLINT(*)
__device__ inline vectorized_loader() {}
__device__ inline ~vectorized_loader() {}
};
vectorized_loader vec_input, vec_gradOutput;
const LType* input_aligned = reinterpret_cast<const LType*>(input);
const LType* gradOutput_aligned = reinterpret_cast<const LType*>(gradOutput);
LType* gradInput_aligned = reinterpret_cast<LType*>(gradInput);
const index_t inner_div_nvec = inner / nvec;
AType sum_gamma = 0;
AType sum_beta = 0;
for (index_t i = threadIdx.x; i < NHW_div_nvec; i += blockDim.x) {
const index_t inner_idx = i % inner_div_nvec;
const index_t outer_idx = i / inner_div_nvec;
const index_t idx = inner_idx + (my_channel + outer_idx * num_channels) * inner_div_nvec;
vec_gradOutput.aligned = gradOutput_aligned[idx];
vec_input.aligned = input_aligned[idx];
#pragma unroll
for (int j = 0; j < nvec; ++j) {
sum_beta += static_cast<AType>(vec_gradOutput.separate[j]);
sum_gamma += static_cast<AType>(vec_gradOutput.separate[j]) *
(static_cast<AType>(vec_input.separate[j]) - mean);
}
if (flags & (WRITE_DATA_FLAG | ADDTO_DATA_FLAG)) {
// Gradient to input
#pragma unroll
for (int j = 0; j < nvec; ++j) {
vec_gradOutput.separate[j] *= invstd * gamma;
}
if (flags & ADDTO_DATA_FLAG) {
vec_input.aligned = gradInput_aligned[idx];
#pragma unroll
for (int j = 0; j < nvec; ++j) {
vec_gradOutput.separate[j] += vec_input.separate[j];
}
}
gradInput_aligned[idx] = vec_gradOutput.aligned;
}
}
sum_gamma =
common::cuda::reduce<NTHREADS, false>(sum_gamma, [](AType a, AType b) { return a + b; });
sum_beta =
common::cuda::reduce<NTHREADS, false>(sum_beta, [](AType a, AType b) { return a + b; });
if (threadIdx.x == 0) {
if (flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) {
if ((flags & FIX_GAMMA_FLAG) == 0) {
if (flags & WRITE_GAMMA_FLAG) {
gradWeight[my_channel] = sum_gamma * invstd;
} else {
gradWeight[my_channel] += sum_gamma * invstd;
}
} else {
if (flags & WRITE_GAMMA_FLAG) {
gradWeight[my_channel] = 0;
}
}
}
if (flags & WRITE_BETA_FLAG) {
gradBias[my_channel] = sum_beta;
} else if (flags & ADDTO_BETA_FLAG) {
gradBias[my_channel] += sum_beta;
}
}
}
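// Training-mode backward, one block per plane. With s = sum(dY),
// d = dot(dY, x - mean) and N = outer * inner, it computes
//   dX     = (dY - (x - mean) * d * invstd^2 / N - s / N) * gamma * invstd,
//   dgamma = d * invstd,  dbeta = s,
// and thread 0 additionally folds the batch statistics into the running
// mean/variance with the momentum update (so the running averages are
// touched in the backward rather than the forward pass).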
template <typename DType, typename AccReal, typename DeviceTensor1, typename DeviceTensor>
static __global__ void BatchNormalizationBackwardKernel(const DeviceTensor input,
const DeviceTensor gradOutput,
DeviceTensor gradInput,
CUDATensors<DeviceTensor1> tensors,
const uint32_t flags,
const AccReal momentum,
const AccReal eps) {
int plane = blockIdx.x;
int N = gradOutput.OuterSize() * gradOutput.InnerSize();
AccReal mean, invstd;
mean = ScalarConvert<DType, AccReal>::to(tensors.saveMean[plane]);
invstd = tensors.saveInvStd[plane];
const AccReal weightVal = ((flags & FIX_GAMMA_FLAG) == 0 && tensors.weight.numElements() > 0) ?
ScalarConvert<DType, AccReal>::to(tensors.weight[plane]) :
AccReal(1);
const AccReal norm = AccReal(1) / N;
// Compute two values across (batch, x/y/z) in one pass:
// 1. Sum(gradOutput)
// 2. DotProduct(input - mean, gradOutput)
GradOp<DType, AccReal, DeviceTensor> g(mean, input, gradOutput);
Float2<DType, AccReal> res =
reduce<Float2<DType, AccReal>, GradOp<DType, AccReal, DeviceTensor>, DeviceTensor>(
g, gradOutput, plane);
const AccReal gradOutputSum = res.v1;
const AccReal dotP = res.v2;
const AccReal gradMean = gradOutputSum * norm;
const AccReal projScale = dotP * norm * invstd * invstd;
const AccReal gradScale = invstd * weightVal;
if (threadIdx.x == 0) {
const AccReal localVariance = invstd_to_variance(tensors.saveInvStd[plane], eps);
const AccReal localMean = tensors.saveMean[plane];
// update running averages
tensors.runningMean[plane] =
tensors.runningMean[plane] * momentum + localMean * (AccReal(1) - momentum);
tensors.runningVar[plane] =
tensors.runningVar[plane] * momentum + localVariance * (AccReal(1) - momentum);
}
if (gradInput.Size() > 0 && (flags & (WRITE_DATA_FLAG | ADDTO_DATA_FLAG)) != 0) {
const bool grad_write = flags & WRITE_DATA_FLAG;
if (grad_write) {
for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
const DType gradOut = gradOutput.get_ref(batch, plane, x);
const DType inp = input.get_ref(batch, plane, x);
const AccReal proj = (inp - mean) * projScale;
gradInput.get_ref(batch, plane, x) =
ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
}
}
} else {
// grad addto
for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
const DType gradOut = gradOutput.get_ref(batch, plane, x);
const DType inp = input.get_ref(batch, plane, x);
const AccReal proj = (inp - mean) * projScale;
gradInput.get_ref(batch, plane, x) +=
ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
}
}
}
}
if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 &&
(flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) != 0) {
if ((flags & FIX_GAMMA_FLAG) == 0) {
if (flags & WRITE_GAMMA_FLAG)
tensors.gradWeight[plane] = ScalarConvert<AccReal, DType>::to(dotP * invstd);
else
tensors.gradWeight[plane] += ScalarConvert<AccReal, DType>::to(dotP * invstd);
} else {
tensors.gradWeight[plane] = DType(0);
}
}
if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 &&
(flags & (WRITE_BETA_FLAG | ADDTO_BETA_FLAG)) != 0) {
if (flags & WRITE_BETA_FLAG)
tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
else
tensors.gradBias[plane] += ScalarConvert<AccReal, DType>::to(gradOutputSum);
}
}
template <typename DType, int Dim>
struct DeviceTensor {
public:
inline DeviceTensor() {}
inline DeviceTensor(DType* p, const int* size) : dptr_(p) {
for (int i = 0; i < Dim; ++i) {
size_[i] = size ? size[i] : 0;
}
}
MSHADOW_XINLINE unsigned getSize(const int i) const {
return size_[i];
}
MSHADOW_XINLINE int numElements() const {
int n = 1;
for (int i = 0; i < Dim; ++i) {
n *= size_[i];
}
return n;
}
MSHADOW_XINLINE DType& operator()(const size_t batch, const size_t plane, const size_t x) const {
    // Row-major offset: (batch * size_[1] + plane) * size_[2] + x
    int offset = batch;
offset *= size_[1];
offset += plane;
offset *= size_[2];
offset += x;
return *(const_cast<DType*>(dptr_ + offset));
}
MSHADOW_XINLINE DType& operator[](const size_t x) const {
return *(dptr_ + x);
}
MSHADOW_XINLINE size_t InnerSize() const {
size_t sz = 1;
for (size_t i = 2; i < Dim; ++i) {
sz *= size_[i];
}
return sz;
}
MSHADOW_XINLINE size_t ChannelCount() const {
return size_[1];
}
DType* dptr_;
int size_[Dim];
};
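// Builds a DeviceTensor<DType, Dim> view of a TBlob: when the blob has more
// than Dim dimensions, the trailing ones are folded into the last axis; when
// it has fewer, the view is padded with size-1 axes. For example, a
// (2, 3, 4, 5) blob viewed with Dim == 3 becomes (2, 3, 20).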
template <typename DType, int Dim>
static DeviceTensor<DType, Dim> devicetensor(const TBlob& blob) {
CHECK_EQ(blob.type_flag_, mshadow::DataType<DType>::kFlag);
DType* data = blob.dptr<DType>();
const int inDim = blob.shape_.ndim();
if (inDim == Dim) {
DeviceTensor<DType, Dim> tensor(data, nullptr);
for (int i = 0; i < Dim; ++i) {
tensor.size_[i] = blob.size(i);
}
return tensor;
}
// View in which the last dimensions are collapsed or expanded as needed
int size[Dim];
for (int i = 0; i < Dim || i < inDim; ++i) {
if (i < Dim && i < inDim) {
size[i] = blob.size(i);
} else if (i < Dim) {
size[i] = 1;
} else {
size[Dim - 1] *= blob.size(i);
}
}
return DeviceTensor<DType, Dim>(data, &size[0]);
}
#define DeviceTensor1 DeviceTensor<AccReal, 1>
using namespace mxnet::op;
template <typename DType, typename AccReal>
static void BatchNormalizationUpdateOutput(mshadow::Stream<gpu>* s,
const OpContext& ctx,
const BatchNormParam& param,
const std::vector<TBlob>& in_data,
const std::vector<TBlob>& out_data,
const std::vector<TBlob>& aux_states,
const uint32_t flags,
double momentum,
double eps) {
batchnorm::BNTensor3<DType> input =
batchnorm::BNTensor3<DType>(in_data[batchnorm::kData], param.axis);
batchnorm::BNTensor3<DType> output =
batchnorm::BNTensor3<DType>(out_data[batchnorm::kOut], param.axis);
DeviceTensor1 weight = devicetensor<AccReal, 1>(in_data[batchnorm::kGamma]);
DeviceTensor1 bias = devicetensor<AccReal, 1>(in_data[batchnorm::kBeta]);
DeviceTensor1 runningMean = devicetensor<AccReal, 1>(aux_states[batchnorm::kMovingMean]);
DeviceTensor1 runningVar = devicetensor<AccReal, 1>(aux_states[batchnorm::kMovingVar]);
DeviceTensor1 saveMean = devicetensor<AccReal, 1>(out_data[batchnorm::kMean]);
DeviceTensor1 saveInvStd = devicetensor<AccReal, 1>(out_data[batchnorm::kVar]);
DCHECK_GT(weight.numElements(), 0);
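  // Inference path (not training, or use_global_stats requested): the
  // transform reduces to a per-channel affine map, so a flat vectorized
  // kernel is launched with double as the load type, processing
  // sizeof(double) / sizeof(DType) elements per transaction.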
if ((flags & IS_TRAINING_FLAG) == 0 || (flags & USE_GLOBAL_STATS_FLAG) != 0) {
AccReal* bias_ptr = bias.numElements() > 0 ? bias.dptr_ : nullptr;
AccReal* gamma_ptr = weight.numElements() > 0 ? weight.dptr_ : nullptr;
int nvec = sizeof(double) / sizeof(DType);
index_t size = input.InnerSize() * input.OuterSize() * input.ChannelCount();
index_t aligned_size = ((size + nvec - 1) / nvec) * nvec;
index_t blocks =
std::min((size + nvec * inference_forward_threads - 1) / (nvec * inference_forward_threads),
static_cast<index_t>(512));
if (input.ChannelCount() < shmem_elements) {
BatchNormalizationUpdateOutputInferenceKernel<DType, AccReal, double, true>
<<<blocks, inference_forward_threads, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
input.dptr_,
output.dptr_,
aligned_size,
input.OuterSize(),
input.ChannelCount(),
input.InnerSize(),
runningMean.dptr_,
runningVar.dptr_,
saveMean.dptr_,
saveInvStd.dptr_,
gamma_ptr,
bias_ptr,
eps,
flags);
} else {
BatchNormalizationUpdateOutputInferenceKernel<DType, AccReal, double, false>
<<<blocks, inference_forward_threads, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
input.dptr_,
output.dptr_,
aligned_size,
input.OuterSize(),
input.ChannelCount(),
input.InnerSize(),
runningMean.dptr_,
runningVar.dptr_,
saveMean.dptr_,
saveInvStd.dptr_,
gamma_ptr,
bias_ptr,
eps,
flags);
}
} else {
dim3 blocks(input.ChannelCount());
dim3 threads(batchnorm::cuda::getNumThreads(input.InnerSize()));
BatchNormalizationUpdateOutputKernel<DType, AccReal, DeviceTensor1, batchnorm::BNTensor3<DType>>
<<<blocks, threads, 0, mshadow::Stream<gpu>::GetStream(s)>>>(input,
output,
weight,
bias,
eps,
momentum,
runningMean,
runningVar,
saveMean,
saveInvStd,
flags);
}
MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormalizationUpdateOutput);
}
template <typename DType, typename AccReal>
static void BatchNormalizationBackward(mshadow::Stream<gpu>* s,
const OpContext& ctx,
const BatchNormParam& param,
const std::vector<TBlob>& out_grad,
const std::vector<TBlob>& in_data,
const std::vector<TBlob>& out_data,
const std::vector<TBlob>& in_grad,
const std::vector<TBlob>& aux_states,
const uint32_t flags,
double momentum,
double eps) {
batchnorm::BNTensor3<DType> input =
batchnorm::BNTensor3<DType>(in_data[batchnorm::kData], param.axis);
batchnorm::BNTensor3<DType> gradOutput =
batchnorm::BNTensor3<DType>(out_grad[batchnorm::kOut], param.axis);
batchnorm::BNTensor3<DType> gradInput =
batchnorm::BNTensor3<DType>(in_grad[batchnorm::kData], param.axis);
CHECK_EQ(gradOutput.Size(), gradInput.Size());
CUDATensors<DeviceTensor1> tensors;
tensors.gradWeight = devicetensor<AccReal, 1>(in_grad[batchnorm::kGamma]);
tensors.gradBias = devicetensor<AccReal, 1>(in_grad[batchnorm::kBeta]);
tensors.weight = devicetensor<AccReal, 1>(in_data[batchnorm::kGamma]);
tensors.runningMean = devicetensor<AccReal, 1>(aux_states[batchnorm::kMovingMean]);
tensors.runningVar = devicetensor<AccReal, 1>(aux_states[batchnorm::kMovingVar]);
tensors.saveMean = devicetensor<AccReal, 1>(out_data[batchnorm::kMean]);
tensors.saveInvStd = devicetensor<AccReal, 1>(out_data[batchnorm::kVar]);
DCHECK_GT(tensors.weight.numElements(), 0);
const bool is_train_and_not_global_stats =
(flags & IS_TRAINING_FLAG) != 0 && (flags & USE_GLOBAL_STATS_FLAG) == 0;
if (is_train_and_not_global_stats) {
dim3 blocks(gradOutput.ChannelCount());
dim3 threads(batchnorm::cuda::getNumThreads(gradOutput.InnerSize()));
BatchNormalizationBackwardKernel<DType, AccReal, DeviceTensor1, batchnorm::BNTensor3<DType>>
<<<blocks, threads, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
input, gradOutput, gradInput, tensors, flags, momentum, eps);
} else {
uint32_t flags_copy = flags;
if (gradInput.Size() <= 0) {
flags_copy = (flags_copy & ~WRITE_DATA_FLAG);
}
if (tensors.gradWeight.numElements() <= 0) {
flags_copy = (flags_copy & ~WRITE_GAMMA_FLAG);
}
if (tensors.gradBias.numElements() <= 0) {
flags_copy = (flags_copy & ~WRITE_BETA_FLAG);
}
AccReal* gamma = ((flags & FIX_GAMMA_FLAG) == 0 && tensors.weight.numElements() > 0) ?
tensors.weight.dptr_ :
nullptr;
if (param.axis == -1 || param.axis == in_data[batchnorm::kData].shape_.ndim() - 1) {
const int C = gradOutput.ChannelCount();
int ltype = mxnet::common::cuda::get_load_type(C * sizeof(DType));
const int M = gradOutput.OuterSize();
MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
const unsigned int blocks_x =
ceil_div(C * sizeof(DType), mxnet::common::cuda::warp_size * sizeof(LType));
const unsigned int preferred_number_of_blocks =
2 * MultiprocessorCount(ctx.run_ctx.ctx.dev_id);
const unsigned int blocks_y = std::max(preferred_number_of_blocks / blocks_x, 1u);
const dim3 n_blocks = {blocks_x, blocks_y, 1};
auto scratch_space = ctx.requested[batchnorm::kTempSpace].get_space_typed<gpu, 1, AccReal>(
mshadow::Shape1(C * blocks_y * 2), s);
auto stream = mshadow::Stream<gpu>::GetStream(s);
constexpr int nthreads_phase1 = 512;
constexpr int nthreads_phase2 = 128;
FrozenBatchNormalizationBackwardKernelCLastPhase1<nthreads_phase1, DType, AccReal, LType>
<<<n_blocks, nthreads_phase1, 0, stream>>>(input.dptr_,
gradOutput.dptr_,
scratch_space.dptr_,
gradInput.dptr_,
gamma,
tensors.runningMean.dptr_,
tensors.runningVar.dptr_,
M,
C,
eps,
flags_copy);
const int nblocks_phase2 = ceil_div(C, nthreads_phase2);
FrozenBatchNormalizationBackwardKernelCLastPhase2<AccReal>
<<<nblocks_phase2, nthreads_phase2, 0, stream>>>(scratch_space.dptr_,
tensors.runningVar.dptr_,
tensors.gradWeight.dptr_,
tensors.gradBias.dptr_,
C,
blocks_y,
eps,
flags_copy);
});
} else {
dim3 blocks(gradOutput.ChannelCount());
int ltype = mxnet::common::cuda::get_load_type(gradOutput.InnerSize() * sizeof(DType));
MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
constexpr int nvec = sizeof(LType) > sizeof(DType) ? sizeof(LType) / sizeof(DType) : 1;
const index_t NHW_div_nvec = gradOutput.OuterSize() * gradOutput.InnerSize() / nvec;
constexpr int threads = 512;
FrozenBatchNormalizationBackwardKernel<threads, DType, AccReal, LType>
<<<blocks, threads, 0, mshadow::Stream<gpu>::GetStream(s)>>>(input.dptr_,
gradOutput.dptr_,
gradInput.dptr_,
tensors.gradWeight.dptr_,
tensors.gradBias.dptr_,
gamma,
tensors.runningMean.dptr_,
tensors.runningVar.dptr_,
gradOutput.OuterSize(),
gradOutput.InnerSize(),
gradOutput.ChannelCount(),
NHW_div_nvec,
eps,
flags_copy);
});
}
}
MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormalizationBackward);
}
} // namespace cuda
} // namespace batchnorm
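/*! \brief Pack is_train, fix_gamma, use_global_stats and the per-output
 *         OpReqType values into the flag word consumed by the kernels above.
 */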
template <typename xpu, typename DType, typename AccReal>
static inline uint32_t SetupFlags(const OpContext& ctx,
const BatchNormParam& params,
const std::vector<OpReqType>& req) {
uint32_t flags = 0;
flags |= ctx.is_train ? IS_TRAINING_FLAG : 0;
flags |= params.fix_gamma ? FIX_GAMMA_FLAG : 0;
flags |= params.use_global_stats ? USE_GLOBAL_STATS_FLAG : 0;
if (IsBNWriting(req[batchnorm::kData])) {
flags |= WRITE_DATA_FLAG;
} else if (req[batchnorm::kData] == kAddTo) {
flags |= ADDTO_DATA_FLAG;
}
if (IsBNWriting(req[batchnorm::kGamma])) {
flags |= WRITE_GAMMA_FLAG;
} else if (req[batchnorm::kGamma] == kAddTo) {
flags |= ADDTO_GAMMA_FLAG;
}
if (IsBNWriting(req[batchnorm::kBeta])) {
flags |= WRITE_BETA_FLAG;
} else if (req[batchnorm::kBeta] == kAddTo) {
flags |= ADDTO_BETA_FLAG;
}
return flags;
}
/*! \brief Forward batch-norm pass on GPU */
template <typename xpu, typename DType, typename AccReal>
void BatchNormForwardImpl(mshadow::Stream<gpu>* stream,
const OpContext& ctx,
const BatchNormParam& param_,
const std::vector<TBlob>& in_data,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& out_data,
const std::vector<TBlob>& aux_states) {
batchnorm::cuda::BatchNormalizationUpdateOutput<DType, AccReal>(
stream,
ctx,
param_,
in_data,
out_data,
aux_states,
SetupFlags<xpu, DType, AccReal>(ctx, param_, req),
param_.momentum,
param_.eps);
MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormOp_DoForward_gpu);
}
/*! \brief Backward batch-norm pass on GPU */
template <typename xpu, typename DType, typename AccReal>
void BatchNormBackwardImpl(mshadow::Stream<gpu>* stream,
const OpContext& ctx,
const BatchNormParam& param_,
const std::vector<TBlob>& out_grad,
const std::vector<TBlob>& in_data,
const std::vector<TBlob>& out_data,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& in_grad,
const std::vector<TBlob>& aux_states) {
batchnorm::cuda::BatchNormalizationBackward<DType, AccReal>(
stream,
ctx,
param_,
out_grad,
in_data,
out_data,
in_grad,
aux_states,
SetupFlags<xpu, DType, AccReal>(ctx, param_, req),
param_.momentum,
param_.eps);
MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormOp_DoBackward_gpu);
}
template <>
void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
BatchNormParam param = nnvm::get<BatchNormParam>(attrs.parsed);
CHECK_EQ(inputs.size(), 5U);
std::vector<TBlob> in_data(inputs.begin(), inputs.begin() + 3);
std::vector<TBlob> aux_states(inputs.begin() + 3, inputs.end());
int dtype = inputs[0].type_flag_;
mxnet::TShape shape = inputs[0].shape_;
param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
#if MXNET_USE_CUDNN == 1
if (!param.use_global_stats && !param.cudnn_off &&
CudnnBatchNormSupports(param, inputs[batchnorm::kData])) {
CudnnBatchNormForward(param, ctx, inputs, req, outputs);
} else {
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
BatchNormForward<gpu, DType, AccReal>(ctx, param, in_data, req, outputs, aux_states);
})
}
#else
  MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
BatchNormForward<gpu, DType, AccReal>(ctx, param, in_data, req, outputs, aux_states);
});
#endif
}
template <>
void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 8U);
BatchNormParam param = nnvm::get<BatchNormParam>(attrs.parsed);
int dtype = inputs[0].type_flag_;
mxnet::TShape shape = inputs[0].shape_;
param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
#if MXNET_USE_CUDNN == 1
if (!param.use_global_stats && !param.cudnn_off &&
CudnnBatchNormSupports(param, inputs[3 + batchnorm::kData])) {
CudnnBatchNormBackward(param, ctx, inputs, req, outputs);
} else {
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
BatchNormBackward<gpu, DType, AccReal>(ctx, param, inputs, req, outputs);
})
}
#else
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
BatchNormBackward<gpu, DType, AccReal>(ctx, param, inputs, req, outputs);
});
#endif
}
NNVM_REGISTER_OP(BatchNorm).set_attr<FCompute>("FCompute<gpu>", BatchNormCompute<gpu>);
NNVM_REGISTER_OP(_backward_BatchNorm)
.set_attr<FCompute>("FCompute<gpu>", BatchNormGradCompute<gpu>);
} // namespace op
} // namespace mxnet