/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file layer_norm.cu
 * \brief Implements Ba et al., Layer Normalization (https://arxiv.org/abs/1607.06450).
*/
#include "./layer_norm-inl.h"
using namespace mshadow::cuda;
namespace mxnet {
namespace op {
template <>
void LayerNormGradComputeGeneralImpl<gpu>(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const TBlob& ograd,
const TBlob& data,
const TBlob& gamma,
const TBlob& mean,
const TBlob& std,
const TBlob& normalized_data,
const TBlob& ograd_mult,
const TBlob& red_out,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs,
const mshadow::Tensor<gpu, 1, char>& workspace,
const mxnet::TShape& red_dst_shape,
const mxnet::TShape& red_src_shape,
const mxnet::TShape& red_exclude_dst_shape,
const mxnet::TShape& red_exclude_src_shape,
const int channel_size) {
using namespace mshadow;
using namespace mshadow::expr;
Stream<gpu>* s = ctx.get_stream<gpu>();
// Compute normalized_data = (data - mean) / std
BinaryBroadcastRTCCompute{"sub"}( // NOLINT
attrs,
ctx,
{data, mean},
{kWriteTo},
{normalized_data});
BinaryBroadcastRTCCompute{"div"}( // NOLINT
attrs,
ctx,
{normalized_data, std},
{kWriteTo},
{normalized_data});
// Calculate grad_beta
if (req[2] != kNullOp) {
BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
broadcast::RTCReduce(ctx,
outputs[2].reshape(red_exclude_dst_shape),
req[2],
workspace,
ograd.reshape(red_exclude_src_shape),
"red::sum{}",
NDim,
"identity");
});
}
  // Calculate grad_gamma; it equals sum(ograd * normalized_data, exclude_axis)
ElemwiseBinaryRTCCompute{"mul"}( // NOLINT
attrs,
ctx,
{normalized_data, ograd},
{kWriteTo},
{ograd_mult});
if (req[1] != kNullOp) {
BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
broadcast::RTCReduce(ctx,
outputs[1].reshape(red_exclude_dst_shape),
req[1],
workspace,
ograd_mult.reshape(red_exclude_src_shape),
"red::sum{}",
NDim,
"identity");
});
}
// Calculate grad_data:
// ograd_mult = ograd * gamma / std
// grad_data = ograd_mult - mean(ograd_mult, axis)
// + normalized_data * (-mean(normalized_data * ograd_mult, axis))
if (req[0] != kNullOp) {
BinaryBroadcastRTCCompute{"mul"}( // NOLINT
attrs,
ctx,
{ograd, gamma},
{kWriteTo},
{ograd_mult});
BinaryBroadcastRTCCompute{"div"}( // NOLINT
attrs,
ctx,
{ograd_mult, std},
{kWriteTo},
{ograd_mult});
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
broadcast::RTCReduce(ctx,
red_out.reshape(red_dst_shape),
kWriteTo,
workspace,
ograd_mult.reshape(red_src_shape),
"red::sum{}",
NDim,
"identity");
});
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Tensor<gpu, 1, DType> red_out_tensor = red_out.FlatTo1D<gpu, DType>(s);
red_out_tensor /= scalar<DType>(channel_size);
});
BinaryBroadcastRTCCompute{"sub"}( // NOLINT
attrs,
ctx,
{ograd_mult, red_out},
{req[0]},
{outputs[0]});
ElemwiseBinaryRTCCompute{"mul"}( // NOLINT
attrs,
ctx,
{ograd_mult, normalized_data},
{kWriteTo},
{ograd_mult});
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
broadcast::RTCReduce(ctx,
red_out.reshape(red_dst_shape),
kWriteTo,
workspace,
ograd_mult.reshape(red_src_shape),
"red::sum{}",
NDim,
"identity");
});
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Tensor<gpu, 1, DType> red_out_tensor = red_out.FlatTo1D<gpu, DType>(s);
red_out_tensor /= scalar<DType>(-channel_size);
});
BinaryBroadcastRTCCompute{"mul"}( // NOLINT
attrs,
ctx,
{normalized_data, red_out},
{kAddTo},
{outputs[0]});
}
}
template <typename DType>
__device__ __forceinline__ DType
warp_shfl(DType value, int src_lane, int width = 32, unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
return __shfl_sync(mask, value, src_lane, width);
#else
return __shfl(value, src_lane, width);
#endif
}
template <typename DType>
__device__ __forceinline__ DType
warp_shfl_xor(DType value, int laneMask, int width = 32, unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
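/* The kernels below use warp_shfl_xor in a butterfly pattern to reduce values within a warp.
 * With blockDim.x == 32 the reduction loop over `mask` exchanges data with lane ^ 16, ^ 8,
 * ^ 4, ^ 2 and ^ 1, so the per-lane partial results are combined in log2(32) = 5 steps
 * (see the loops in LayerNormFusedForwardKernelContig and LayerNormFusedBackwardKernel_Data).
 */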
/* A single updating step of Welford's online algorithm for computing the mean and variance.
 * The value 'curr' is accumulated into the (mean, sigma2, count) triplet.
 */
template <typename DType, typename IType>
__device__ __forceinline__ void StepWelfordOnlineSum(const DType curr,
DType& mean, // NOLINT
DType& sigma2, // NOLINT
IType& count) { // NOLINT
count += IType(1);
DType delta = curr - mean;
mean += delta / count;
sigma2 += delta * (curr - mean);
}
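/* For reference, StepWelfordOnlineSum implements the standard Welford recurrences
 * (sigma2 holds the running sum of squared deviations, i.e. M2, not the variance itself):
 *   count_n = count_{n-1} + 1
 *   mean_n  = mean_{n-1} + (x_n - mean_{n-1}) / count_n
 *   M2_n    = M2_{n-1} + (x_n - mean_{n-1}) * (x_n - mean_n)
 * The biased variance is M2_n / count_n, which is what the forward kernel obtains by dividing
 * sigma2 by nchannel once the whole row has been accumulated.
 */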
/* Merge the mean/variance of two partitions. This is the key step of Chan's parallel algorithm.
 * The (lhs_mean, lhs_sigma2, lhs_count) triplet is merged into (rhs_mean, rhs_sigma2, rhs_count).
 *
 * See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance for more details.
 *
 * TODO(sxjscience) Explore the possibility of int lhs_count and rhs_count
 */
template <typename DType, typename IType>
__device__ __inline__ void ChanMergePartition(const DType lhs_mean,
const DType lhs_sigma2,
const IType lhs_count,
DType& rhs_mean, // NOLINT
DType& rhs_sigma2, // NOLINT
IType& rhs_count) { // NOLINT
DType delta = rhs_mean - lhs_mean;
DType nA = static_cast<DType>(lhs_count);
DType nB = static_cast<DType>(rhs_count);
rhs_count = nA + nB;
if (rhs_count > DType(0)) {
nA = nA / rhs_count;
nB = nB / rhs_count;
rhs_mean = nA * lhs_mean + nB * rhs_mean;
rhs_sigma2 = rhs_sigma2 + lhs_sigma2 + delta * delta * nA * nB * rhs_count;
} else {
rhs_mean = DType(0);
rhs_sigma2 = DType(0);
}
}
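/* For reference, with nA = lhs_count and nB = rhs_count the merge above is equivalent to
 *   n     = nA + nB
 *   delta = rhs_mean - lhs_mean
 *   mean  = (nA * lhs_mean + nB * rhs_mean) / n
 *   M2    = lhs_sigma2 + rhs_sigma2 + delta^2 * nA * nB / n
 * i.e. the pairwise combination rule for the (mean, M2, count) statistics of two disjoint
 * partitions; normalizing nA and nB by n in the code avoids forming nA * nB directly.
 */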
/* Split the input column into multiple partitions and compute the mean/sigma of each partition.
* Each thread will keep a mean/sigma2. The mean/sigma2 can be further merged to get the mean and
* sigma2 of the column.
*/
template <typename AType, typename DType, typename IType>
__device__ __forceinline__ void BlockWelfordOnlineSum(const DType* __restrict__ col_vals,
const int nchannel,
AType& mean, // NOLINT
AType& sigma2, // NOLINT
IType& count) { // NOLINT
int tid = threadIdx.x + threadIdx.y * blockDim.x;
const int nthread = blockDim.x * blockDim.y;
  // Each thread handles 4 consecutive elements per iteration, which encourages vectorized
  // loads (e.g. float4).
  // Also, to minimize branch divergence, the loop is split into an unrolled main part and a tail.
int l = 4 * tid;
for (; l + 3 < nchannel; l += 4 * nthread) {
#pragma unroll
for (int i = 0; i < 4; ++i) {
StepWelfordOnlineSum(static_cast<AType>(col_vals[l + i]), mean, sigma2, count);
}
}
for (; l < nchannel; ++l) {
StepWelfordOnlineSum(static_cast<AType>(col_vals[l]), mean, sigma2, count);
}
}
template <>
__device__ __forceinline__ void BlockWelfordOnlineSum<float, mshadow::half::half_t, int>(
const mshadow::half::half_t* __restrict__ col_vals,
const int nchannel,
float& mean, // NOLINT
float& sigma2, // NOLINT
int& count) { // NOLINT
int tid = threadIdx.x + threadIdx.y * blockDim.x;
const int nthread = blockDim.x * blockDim.y;
  // We cast the input half pointer to half2 to speed up loading.
  // Note that CUDA requires aligned accesses, i.e.,
  //   reinterpret_cast<size_t>(ptr) % sizeof(dtype) == 0.
  // Thus, we may need to shift the start of the half pointer so that it is aligned to half2.
int align_shift = (reinterpret_cast<size_t>(col_vals) % 4) != 0;
int padding = (nchannel - align_shift) % 2;
int half2_size = (nchannel - align_shift) / 2;
const __half2* half2_col_vals = reinterpret_cast<const __half2*>(col_vals + align_shift);
if (threadIdx.x == 0 && threadIdx.y == 0) {
if (align_shift) {
StepWelfordOnlineSum(__half2float(col_vals[0].cuhalf_), mean, sigma2, count);
}
if (padding) {
StepWelfordOnlineSum(__half2float(col_vals[nchannel - 1].cuhalf_), mean, sigma2, count);
}
}
for (int l = tid; l < half2_size; l += nthread) {
float2 ele_val = __half22float2(half2_col_vals[l]);
StepWelfordOnlineSum(ele_val.x, mean, sigma2, count);
StepWelfordOnlineSum(ele_val.y, mean, sigma2, count);
}
}
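/* Illustrative example of the alignment handling above: if col_vals is 2-byte but not 4-byte
 * aligned and nchannel == 7, then align_shift == 1, half2_size == 3 and padding == 0, so
 * thread (0, 0) consumes col_vals[0] as a scalar and the remaining six values are read as
 * three aligned __half2 loads.
 */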
/* Fused CUDA kernel for the forward pass of layer normalization.
* It computes the LayerNorm when axis=-1, i.e., contiguous reduction scenario.
* Shape of the input tensors:
* in_data = (nbatch, nchannel)
* gamma = (nchannel,)
* beta = (nchannel,)
 *   out_data = (nbatch, nchannel)
 *   mean_data = (nbatch,)
 *   std_data = (nbatch,)
 * The kernel is always launched with blockDim.x == WARP_SIZE (32).
 * When blockDim.y > 1, it requires dynamic shared memory of size:
 *   sizeof(AType) * blockDim.x * blockDim.y + sizeof(int) * blockDim.x * blockDim.y / 2
*/
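/* Layout of the dynamic shared memory used by the inter-warp reduction below:
 *   mean_buf   : AType[blockDim.y / 2 * blockDim.x] starting at buf
 *   sigma2_buf : AType[blockDim.y / 2 * blockDim.x] starting right after mean_buf
 *   count_buf  : IType[blockDim.y / 2 * blockDim.x] starting right after sigma2_buf
 * which adds up to the size quoted above (IType is instantiated as int by the launcher).
 */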
template <typename AType, typename DType, typename IType>
__global__ void LayerNormFusedForwardKernelContig(const int nbatch,
const int nchannel,
const AType eps,
const DType* __restrict__ in_data,
const DType* __restrict__ gamma,
const DType* __restrict__ beta,
DType* __restrict__ out_data,
DType* __restrict__ mean_data,
DType* __restrict__ std_data) {
int bid = blockIdx.x + blockIdx.y * gridDim.x;
const int tid = threadIdx.y * blockDim.x + threadIdx.x;
const int nthread = blockDim.x * blockDim.y;
IType count = 0;
AType mean = 0;
AType sigma2 = 0;
if (bid < nbatch) {
extern __shared__ char buf[]; // Shared memory
const DType* col_vals = in_data + bid * nchannel;
BlockWelfordOnlineSum(col_vals, nchannel, mean, sigma2, count);
    // Merge the mean/sigma2 within a warp.
    // Use Chan's parallel algorithm to merge all (mean, sigma2, count) triplets
    // within a warp of threads.
    // After the loop, the thread with threadIdx.x == 0 holds the aggregated
    // (mean, sigma2, count).
for (int mask = blockDim.x / 2; mask > 0; mask >>= 1) {
AType meanB = warp_shfl_xor(mean, mask);
AType sigma2B = warp_shfl_xor(sigma2, mask);
IType countB = warp_shfl_xor(count, mask);
ChanMergePartition(meanB, sigma2B, countB, mean, sigma2, count);
}
if (blockDim.y > 1) {
// Inter-warp reduction. Copy the upper-half of the warps to shared memory
// and merge with the lower-half warp
AType* mean_buf = reinterpret_cast<AType*>(buf);
AType* sigma2_buf =
reinterpret_cast<AType*>(buf + sizeof(AType) * blockDim.y / 2 * blockDim.x);
IType* count_buf = reinterpret_cast<IType*>(buf + sizeof(AType) * blockDim.y * blockDim.x);
for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
mean_buf[idx] = mean;
sigma2_buf[idx] = sigma2;
count_buf[idx] = count;
}
__syncthreads();
if (threadIdx.y < offset) {
const int idx = threadIdx.y * blockDim.x + threadIdx.x;
ChanMergePartition(mean_buf[idx], sigma2_buf[idx], count_buf[idx], mean, sigma2, count);
}
__syncthreads();
}
// Broadcast the result to all threads
if (threadIdx.y == 0) {
mean_buf[threadIdx.x] = mean;
sigma2_buf[threadIdx.x] = sigma2;
}
__syncthreads();
mean = mean_buf[threadIdx.x];
sigma2 = sigma2_buf[threadIdx.x] / nchannel;
} else {
sigma2 /= nchannel;
}
// Calculate the out_data: gamma * (x - mean) / sqrt(var + eps) + beta
AType std_eps = sqrt(sigma2 + eps);
AType invstd_eps = DType(1.0) / std_eps;
DType* out_col_val = out_data + bid * nchannel;
if (gamma != nullptr && beta != nullptr) {
for (int i = tid; i < nchannel; i += nthread) {
out_col_val[i] =
gamma[i] * static_cast<DType>(invstd_eps * (static_cast<AType>(col_vals[i]) - mean)) +
beta[i];
}
} else if (gamma == nullptr && beta != nullptr) {
for (int i = tid; i < nchannel; i += nthread) {
out_col_val[i] =
static_cast<DType>(invstd_eps * (static_cast<AType>(col_vals[i]) - mean)) + beta[i];
}
} else if (gamma != nullptr && beta == nullptr) {
for (int i = tid; i < nchannel; i += nthread) {
out_col_val[i] =
gamma[i] * static_cast<DType>(invstd_eps * (static_cast<AType>(col_vals[i]) - mean));
}
} else {
for (int i = tid; i < nchannel; i += nthread) {
out_col_val[i] = static_cast<DType>(invstd_eps * (static_cast<AType>(col_vals[i]) - mean));
}
}
// Write the out_data and var_data
if (threadIdx.x == 0 && threadIdx.y == 0) {
mean_data[bid] = static_cast<DType>(mean);
std_data[bid] = static_cast<DType>(std_eps);
}
}
}
template <bool safe_acc = false>
void LayerNormGPUContig(const LayerNormParam param,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
CHECK_EQ(inputs.size(), 3U);
mxnet::TShape data_shape(2, 0);
mxnet::TShape mean_shape(1, 0);
size_t in_ndim = inputs[layernorm::kData].ndim();
data_shape[0] = mean_shape[0] = inputs[layernorm::kData].shape_.ProdShape(0, in_ndim - 1);
data_shape[1] = inputs[layernorm::kData].shape_[in_ndim - 1];
const TBlob in_data = inputs[layernorm::kData].reshape(data_shape);
const TBlob gamma = inputs[layernorm::kGamma];
const TBlob beta = inputs[layernorm::kBeta];
const TBlob out_data = outputs[layernorm::kOut].reshape(data_shape);
const TBlob mean_data = outputs[layernorm::kMean].reshape(mean_shape);
const TBlob std_data = outputs[layernorm::kStd].reshape(mean_shape);
// Make sure the inputs are contiguous
CHECK_EQ(in_data.CheckContiguous(), true);
CHECK_EQ(gamma.CheckContiguous(), true);
CHECK_EQ(beta.CheckContiguous(), true);
CHECK_EQ(out_data.CheckContiguous(), true);
CHECK_EQ(mean_data.CheckContiguous(), true);
CHECK_EQ(std_data.CheckContiguous(), true);
  // Launch the kernel. The dynamic shared memory size is
  // sizeof(AType) * blockDim.y * blockDim.x + sizeof(int) * blockDim.y / 2 * blockDim.x
int nbatch = data_shape[0];
int nchannel = data_shape[1];
float eps = param.eps;
int ngrid_x = (nbatch > kMaxGridDim) ? (nbatch + kBaseGridNum - 1) / kBaseGridNum : nbatch;
int ngrid_y = (nbatch > kMaxGridDim) ? kBaseGridNum : 1;
int nthread_y;
const dim3 dimGrid(ngrid_x, ngrid_y);
if (nchannel <= 128) {
nthread_y = 1;
} else if (nchannel <= 512) {
nthread_y = 2;
} else {
nthread_y = 4;
}
cudaStream_t stream = Stream<gpu>::GetStream(ctx.get_stream<gpu>());
const dim3 dimBlock(32, nthread_y);
MXNET_REAL_ACC_TYPE_SWITCH(in_data.type_flag_, DType, AccType, {
typedef typename std::conditional<safe_acc, AccType, DType>::type AType;
int nshared =
nthread_y > 1 ? nthread_y * 32 * sizeof(AType) + (nthread_y / 2) * 32 * sizeof(int) : 0;
CheckLaunchParam(dimGrid, dimBlock);
LayerNormFusedForwardKernelContig<AType, DType, int>
<<<dimGrid, dimBlock, nshared, stream>>>(nbatch,
nchannel,
static_cast<AType>(eps),
in_data.dptr<DType>(),
gamma.dptr<DType>(),
beta.dptr<DType>(),
out_data.dptr<DType>(),
mean_data.dptr<DType>(),
std_data.dptr<DType>());
});
MSHADOW_CUDA_POST_KERNEL_CHECK(LayerNormFusedForwardKernelContig);
}
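/* Illustrative launch configuration: with MXNET_SAFE_ACCUMULATION disabled and float32 input
 * (so AType == DType == float), nbatch = 128 and nchannel = 768 give dimGrid = (128, 1),
 * dimBlock = (32, 4) and nshared = 4 * 32 * sizeof(float) + 2 * 32 * sizeof(int) = 768 bytes.
 */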
template <>
void LayerNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
if (req[0] == kNullOp)
return;
CHECK_NE(req[0], kAddTo);
int axis = param.axis;
if (axis < 0) {
axis += static_cast<int>(inputs[0].ndim());
}
CHECK(axis >= 0 && axis < inputs[0].ndim()) << "Channel axis out of range: " << param.axis;
if (axis == inputs[0].ndim() - 1) {
// Try to use the accelerated CUDA kernels
bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true);
if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) {
common::LogOnce(
"MXNET_SAFE_ACCUMULATION=1 is recommended for LayerNorm with float16 inputs. "
"See https://mxnet.apache.org/api/faq/env_var "
"for more details.");
}
if (safe_acc) {
return LayerNormGPUContig<true>(param, ctx, inputs, req, outputs);
} else {
return LayerNormGPUContig<false>(param, ctx, inputs, req, outputs);
}
}
return LayerNormComputeGeneral<gpu>(attrs, ctx, inputs, req, outputs);
}
/* Fused CUDA kernel for calculating the gradient w.r.t gamma/beta in LayerNorm when axis=-1
* (Contiguous case).
* The gradient of gamma and beta are:
* d_gamma = sum(out_grad * (x - mean) / std, axis=0)
* d_beta = sum(out_grad, axis=0)
*
* We compute the gradient (mainly reduction over a non-contiguous axis) using two steps to
* improve the parallelism.
*
* In the first step, we divide the rows uniformly into K parts. K independent threadblocks are used
* to calculate the partial reduction result of each part. Illustrated below:
*
* 1st Block 2nd Block 3rd Block k-th Block
* | --------------- | ---------------- | --------------- | ... | ---------------- |
* | --------------- | ---------------- | --------------- | ... | ---------------- |
* | --------------- | ---------------- | --------------- | ... | ---------------- |
* | --------------- | ---------------- | --------------- | ... | ---------------- |
* part_gamma[0] part_gamma[1] part_gamma[2] part_gamma[k-1]
* part_beta[0] part_beta[1] part_beta[2] part_beta[k-1]
*
*
* In the second step, we sum up the row-values in part_gamma and part_beta.
*
* This `LayerNormFusedBackwardKernel_PartGammaBeta` function implements the first step and
* `LayerNormFusedBackwardKernel_GammaBeta` implements the second step.
*/
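/* In formulas, with x_hat[r, c] = (x[r, c] - mean[r]) / std[r], the first kernel computes for
 * every part k and channel c
 *   part_gamma_grad[k, c] = sum_{r in part k} out_grad[r, c] * x_hat[r, c]
 *   part_beta_grad[k, c]  = sum_{r in part k} out_grad[r, c]
 * and the second kernel reduces over the parts:
 *   gamma_grad[c] = sum_k part_gamma_grad[k, c]
 *   beta_grad[c]  = sum_k part_beta_grad[k, c]
 */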
template <typename AType, typename DType>
__global__ void LayerNormFusedBackwardKernel_PartGammaBeta(const int nbatch,
const int nchannel,
const DType* __restrict__ in_data,
const DType* __restrict__ out_grad,
const DType* __restrict__ mean_data,
const DType* __restrict__ std_data,
AType* __restrict__ part_gamma_grad,
AType* __restrict__ part_beta_grad) {
extern __shared__ char buf[];
AType* d_buf = reinterpret_cast<AType*>(buf);
const int npart = gridDim.y;
const int block_row_num = (nbatch + npart - 1) / npart;
// The rows are divided into `npart` parts. Each threadblock calculates the reduction result
// within the corresponding row ranges.
int row_stride = blockDim.x + 1;
const int c = blockIdx.x * blockDim.x + threadIdx.x;
int r_begin = blockIdx.y * block_row_num;
int r_end = min((blockIdx.y + 1) * block_row_num, nbatch);
AType* buf_gamma_grad = d_buf;
AType* buf_beta_grad = d_buf + blockDim.y * row_stride;
AType local_gamma_grad = 0;
AType local_beta_grad = 0;
if (c < nchannel) {
for (int r_b = r_begin; r_b < r_end; r_b += blockDim.y) {
int r = r_b + threadIdx.y;
if (r < r_end) {
AType local_mean = static_cast<AType>(mean_data[r]);
AType local_std = static_cast<AType>(std_data[r]);
int read_idx = r * nchannel + c;
AType local_in_data = static_cast<AType>(in_data[read_idx]);
AType local_out_grad = static_cast<AType>(out_grad[read_idx]);
local_gamma_grad += (local_in_data - local_mean) / local_std * local_out_grad;
local_beta_grad += local_out_grad;
}
}
}
buf_gamma_grad[threadIdx.y * row_stride + threadIdx.x] = local_gamma_grad;
buf_beta_grad[threadIdx.y * row_stride + threadIdx.x] = local_beta_grad;
__syncthreads();
for (int offset = blockDim.y / 2; offset > 1; offset >>= 1) {
if (threadIdx.y < offset) {
int idx1 = threadIdx.y * row_stride + threadIdx.x;
int idx2 = (threadIdx.y + offset) * row_stride + threadIdx.x;
buf_gamma_grad[idx1] += buf_gamma_grad[idx2];
buf_beta_grad[idx1] += buf_beta_grad[idx2];
}
__syncthreads();
}
if (threadIdx.y == 0 && c < nchannel) {
part_gamma_grad[blockIdx.y * nchannel + c] =
buf_gamma_grad[threadIdx.x] + buf_gamma_grad[threadIdx.x + row_stride];
part_beta_grad[blockIdx.y * nchannel + c] =
buf_beta_grad[threadIdx.x] + buf_beta_grad[threadIdx.x + row_stride];
}
}
template <bool gamma_addto, bool beta_addto, typename AType, typename DType>
__global__ void LayerNormFusedBackwardKernel_GammaBeta(const int nbatch,
const int nchannel,
const int npart,
const AType* __restrict__ part_gamma_grad,
const AType* __restrict__ part_beta_grad,
DType* gamma_grad,
DType* beta_grad) {
const int c = blockIdx.x * blockDim.x + threadIdx.x;
const int tid = threadIdx.y * blockDim.x + threadIdx.x;
if (c < nchannel) {
extern __shared__ char buf[];
AType* buf_gamma_grad = reinterpret_cast<AType*>(buf);
AType* buf_beta_grad = reinterpret_cast<AType*>(buf) + blockDim.x * blockDim.y;
buf_gamma_grad[tid] = 0;
buf_beta_grad[tid] = 0;
for (int r = threadIdx.y; r < npart; r += blockDim.y) {
buf_gamma_grad[tid] += part_gamma_grad[r * nchannel + c];
buf_beta_grad[tid] += part_beta_grad[r * nchannel + c];
}
__syncthreads();
    // Inter-warp reduction
if (npart > 1) {
for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
if (threadIdx.y < offset) {
int idx1 = tid;
int idx2 = tid + offset * blockDim.x;
buf_gamma_grad[idx1] += buf_gamma_grad[idx2];
buf_beta_grad[idx1] += buf_beta_grad[idx2];
}
__syncthreads();
}
}
if (threadIdx.y == 0) {
if (gamma_grad) {
if (gamma_addto) {
gamma_grad[c] += static_cast<DType>(buf_gamma_grad[threadIdx.x]);
} else {
gamma_grad[c] = static_cast<DType>(buf_gamma_grad[threadIdx.x]);
}
}
if (beta_grad) {
if (beta_addto) {
beta_grad[c] += static_cast<DType>(buf_beta_grad[threadIdx.x]);
} else {
beta_grad[c] = static_cast<DType>(buf_beta_grad[threadIdx.x]);
}
}
}
}
}
/* Fused CUDA kernel for calculating the gradient w.r.t the data in LayerNorm when axis=-1
 * (Contiguous case).
 * With x_hat = (x - mean) / std, the data gradient is
 *   data_grad = out_grad * gamma / std
 *               - mean(out_grad * gamma / std, axis=-1)
 *               - x_hat * mean(out_grad * gamma / std * x_hat, axis=-1)
 * Each threadblock handles one instance (one row of the flattened input): it first reduces the
 * two row-wise means above and then writes (or accumulates, when data_addto) the gradient.
 */
template <int LOAD_UNROLL, bool data_addto, typename AType, typename DType>
__global__ void LayerNormFusedBackwardKernel_Data(const int nbatch,
const int nchannel,
const DType* __restrict__ in_data,
const DType* __restrict__ out_grad,
const DType* __restrict__ mean_data,
const DType* __restrict__ std_data,
const DType* __restrict__ gamma,
DType* data_grad) {
int bid = blockIdx.x + blockIdx.y * gridDim.x;
const int nthread = blockDim.x * blockDim.y;
if (bid < nbatch) {
// Shared memory with size blockDim.y * blockDim.x * sizeof(DType)
extern __shared__ char buf[];
int tid = threadIdx.x + threadIdx.y * blockDim.x;
// 1. Calculate: mean(out_grad * gamma / std, axis=-1)
// mean(out_grad * gamma / std * (x - mean) / std, axis=-1)
AType sum_val0 = 0; // Stores mean(out_grad * gamma / std, axis=-1)
AType sum_val1 = 0; // Stores mean(out_grad * gamma / std * (x - mean) / std, axis=-1)
AType mean = static_cast<AType>(mean_data[bid]);
AType invstd_eps = AType(1) / static_cast<AType>(std_data[bid]);
int l = LOAD_UNROLL * tid;
for (; l + LOAD_UNROLL - 1 < nchannel; l += nthread * LOAD_UNROLL) {
#pragma unroll
for (int i = 0; i < LOAD_UNROLL; ++i) {
AType ele_og = static_cast<AType>(out_grad[bid * nchannel + l + i]);
AType ele_x = static_cast<AType>(in_data[bid * nchannel + l + i]);
AType ele_gamma = static_cast<AType>(gamma[l + i]);
sum_val0 += ele_og * ele_gamma * invstd_eps;
sum_val1 += ele_og * ele_gamma * (ele_x - mean) * invstd_eps * invstd_eps;
}
}
for (; l < nchannel; ++l) {
AType ele_og = static_cast<AType>(out_grad[bid * nchannel + l]);
AType ele_x = static_cast<AType>(in_data[bid * nchannel + l]);
AType ele_gamma = static_cast<AType>(gamma[l]);
sum_val0 += ele_og * ele_gamma * invstd_eps;
sum_val1 += ele_og * ele_gamma * (ele_x - mean) * invstd_eps * invstd_eps;
}
// Intra-warp reduction (all-reduce)
for (int mask = blockDim.x / 2; mask > 0; mask >>= 1) {
sum_val0 += warp_shfl_xor(sum_val0, mask);
sum_val1 += warp_shfl_xor(sum_val1, mask);
}
// Inter-warp reduction (all-reduce)
if (blockDim.y > 1) {
AType* sum_val0_buf = reinterpret_cast<AType*>(buf);
AType* sum_val1_buf =
reinterpret_cast<AType*>(buf + blockDim.y / 2 * blockDim.x * sizeof(AType));
for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
sum_val0_buf[idx] = sum_val0;
sum_val1_buf[idx] = sum_val1;
}
__syncthreads();
if (threadIdx.y < offset) {
const int idx = threadIdx.y * blockDim.x + threadIdx.x;
sum_val0 += sum_val0_buf[idx];
sum_val1 += sum_val1_buf[idx];
}
__syncthreads();
}
if (threadIdx.y == 0) {
sum_val0_buf[threadIdx.x] = sum_val0;
sum_val1_buf[threadIdx.x] = sum_val1;
}
__syncthreads();
sum_val0 = sum_val0_buf[threadIdx.x];
sum_val1 = sum_val1_buf[threadIdx.x];
}
sum_val0 /= nchannel;
sum_val1 /= nchannel;
// 2. Calculate the gradient as
// out_grad * gamma / std - sum_val0 - (x - mean) / std * sum_val1
for (int l = tid; l < nchannel; l += nthread) {
AType ele_out_grad = static_cast<AType>(out_grad[bid * nchannel + l]);
AType ele_x = static_cast<AType>(in_data[bid * nchannel + l]);
AType ele_gamma = static_cast<AType>(gamma[l]);
if (data_addto) {
data_grad[bid * nchannel + l] +=
static_cast<DType>(ele_out_grad * ele_gamma * invstd_eps - sum_val0 -
(ele_x - mean) * invstd_eps * sum_val1);
} else {
data_grad[bid * nchannel + l] =
static_cast<DType>(ele_out_grad * ele_gamma * invstd_eps - sum_val0 -
(ele_x - mean) * invstd_eps * sum_val1);
}
}
}
}
void GetGammaBetaGradKernelParams(const int nbatch,
const int nchannel,
dim3* part_grad_block_dim,
dim3* part_grad_grid_dim,
dim3* gb_block_dim,
dim3* gb_grid_dim,
int* npart) {
*npart = 16;
*part_grad_block_dim = dim3(32, 16);
*part_grad_grid_dim = dim3((nchannel + 32 - 1) / 32, *npart);
*gb_block_dim = dim3(32, *npart);
*gb_grid_dim = dim3((nchannel + 32 - 1) / 32);
CheckLaunchParam(*part_grad_grid_dim, *part_grad_block_dim);
CheckLaunchParam(*gb_grid_dim, *gb_block_dim);
}
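/* Illustrative example: for nchannel = 1024 the parameters above are npart = 16,
 * part_grad_block_dim = (32, 16), part_grad_grid_dim = (32, 16), gb_block_dim = (32, 16) and
 * gb_grid_dim = (32,), i.e. 16 partial-reduction blocks per group of 32 channels followed by
 * one block per group that folds the 16 partial rows together.
 */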
template <bool safe_acc = false>
void LayerNormGradGPUContig(const LayerNormParam param,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
#if MXNET_USE_ONEDNN == 1
CHECK_EQ(inputs.size(), 6U); // additional beta tensor
#else
CHECK_EQ(inputs.size(), 5U);
#endif
const TBlob out_grad = inputs[0];
const TBlob in_data = inputs[1];
const TBlob gamma = inputs[2];
const TBlob mean_data = inputs[3];
const TBlob std_data = inputs[4];
const TBlob data_grad = outputs[0];
const TBlob gamma_grad = outputs[1];
const TBlob beta_grad = outputs[2];
// Make sure the inputs are contiguous
CHECK_EQ(out_grad.CheckContiguous(), true);
CHECK_EQ(in_data.CheckContiguous(), true);
CHECK_EQ(gamma.CheckContiguous(), true);
CHECK_EQ(mean_data.CheckContiguous(), true);
CHECK_EQ(std_data.CheckContiguous(), true);
int nbatch = in_data.shape_.ProdShape(0, in_data.ndim() - 1);
int nchannel = in_data.shape_[in_data.ndim() - 1];
int data_grad_req = req[0];
int gamma_grad_req = req[1];
int beta_grad_req = req[2];
CHECK_NE(data_grad_req, kWriteInplace);
CHECK_NE(gamma_grad_req, kWriteInplace);
CHECK_NE(beta_grad_req, kWriteInplace);
Stream<gpu>* s = ctx.get_stream<gpu>();
cudaStream_t stream = Stream<gpu>::GetStream(s);
// Calculate the gradient for gamma/beta
CHECK_EQ(gamma_grad.CheckContiguous(), true);
CHECK_EQ(beta_grad.CheckContiguous(), true);
dim3 part_grad_block_dim, part_grad_grid_dim, gb_block_dim, gb_grid_dim;
int npart;
GetGammaBetaGradKernelParams(nbatch,
nchannel,
&part_grad_block_dim,
&part_grad_grid_dim,
&gb_block_dim,
&gb_grid_dim,
&npart);
if (gamma_grad_req != kNullOp || beta_grad_req != kNullOp) {
MXNET_REAL_ACC_TYPE_SWITCH(in_data.type_flag_, DType, AccType, {
typedef typename std::conditional<safe_acc, AccType, DType>::type AType;
Tensor<gpu, 1, AType> workspace =
ctx.requested[0].get_space_typed<gpu, 1, AType>(Shape1(2 * npart * nchannel), s);
AType* part_gamma_grad_ptr = workspace.dptr_;
AType* part_beta_grad_ptr = workspace.dptr_ + npart * nchannel;
const int nshared_K1 =
2 * (part_grad_block_dim.x + 1) * part_grad_block_dim.y * sizeof(AType);
const int nshared_K2 = 2 * gb_block_dim.x * gb_block_dim.y * sizeof(AType);
DType* gamma_grad_ptr = (gamma_grad_req != kNullOp) ? gamma_grad.dptr<DType>() : nullptr;
DType* beta_grad_ptr = (beta_grad_req != kNullOp) ? beta_grad.dptr<DType>() : nullptr;
LayerNormFusedBackwardKernel_PartGammaBeta<<<part_grad_grid_dim,
part_grad_block_dim,
nshared_K1,
stream>>>(nbatch,
nchannel,
in_data.dptr<DType>(),
out_grad.dptr<DType>(),
mean_data.dptr<DType>(),
std_data.dptr<DType>(),
part_gamma_grad_ptr,
part_beta_grad_ptr);
MSHADOW_CUDA_POST_KERNEL_CHECK(LayerNormFusedBackwardKernel_PartGammaBeta);
if (gamma_grad_req == kAddTo && beta_grad_req != kAddTo) {
LayerNormFusedBackwardKernel_GammaBeta<true, false>
<<<gb_grid_dim, gb_block_dim, nshared_K2, stream>>>(nbatch,
nchannel,
npart,
part_gamma_grad_ptr,
part_beta_grad_ptr,
gamma_grad_ptr,
beta_grad_ptr);
} else if (gamma_grad_req != kAddTo && beta_grad_req == kAddTo) {
LayerNormFusedBackwardKernel_GammaBeta<false, true>
<<<gb_grid_dim, gb_block_dim, nshared_K2, stream>>>(nbatch,
nchannel,
npart,
part_gamma_grad_ptr,
part_beta_grad_ptr,
gamma_grad_ptr,
beta_grad_ptr);
} else if (gamma_grad_req == kAddTo && beta_grad_req == kAddTo) {
LayerNormFusedBackwardKernel_GammaBeta<true, true>
<<<gb_grid_dim, gb_block_dim, nshared_K2, stream>>>(nbatch,
nchannel,
npart,
part_gamma_grad_ptr,
part_beta_grad_ptr,
gamma_grad_ptr,
beta_grad_ptr);
} else {
LayerNormFusedBackwardKernel_GammaBeta<false, false>
<<<gb_grid_dim, gb_block_dim, nshared_K2, stream>>>(nbatch,
nchannel,
npart,
part_gamma_grad_ptr,
part_beta_grad_ptr,
gamma_grad_ptr,
beta_grad_ptr);
}
});
MSHADOW_CUDA_POST_KERNEL_CHECK(LayerNormFusedBackwardKernel_GammaBeta);
}
// Calculate the gradient for data
CHECK_EQ(data_grad.CheckContiguous(), true);
int ngrid_x = (nbatch > kMaxGridDim) ? (nbatch + kBaseGridNum - 1) / kBaseGridNum : nbatch;
int ngrid_y = (nbatch > kMaxGridDim) ? kBaseGridNum : 1;
const dim3 data_grid_dim(ngrid_x, ngrid_y);
int nthread_y;
if (nchannel <= 32) {
nthread_y = 1;
} else if (nchannel <= 128) {
nthread_y = 2;
} else if (nchannel <= 512) {
nthread_y = 4;
} else {
nthread_y = 8;
}
const dim3 data_block_dim(32, nthread_y);
const int LOAD_UNROLL = 4;
if (data_grad_req != kNullOp) {
MXNET_REAL_ACC_TYPE_SWITCH(in_data.type_flag_, DType, AccType, {
typedef typename std::conditional<safe_acc, AccType, DType>::type AType;
int nshared = data_block_dim.y > 1 ? data_block_dim.y * data_block_dim.x * sizeof(AType) : 0;
CheckLaunchParam(data_grid_dim, data_block_dim);
if (data_grad_req == kAddTo) {
LayerNormFusedBackwardKernel_Data<LOAD_UNROLL, true, AType>
<<<data_grid_dim, data_block_dim, nshared, stream>>>(nbatch,
nchannel,
in_data.dptr<DType>(),
out_grad.dptr<DType>(),
mean_data.dptr<DType>(),
std_data.dptr<DType>(),
gamma.dptr<DType>(),
data_grad.dptr<DType>());
} else {
LayerNormFusedBackwardKernel_Data<LOAD_UNROLL, false, AType>
<<<data_grid_dim, data_block_dim, nshared, stream>>>(nbatch,
nchannel,
in_data.dptr<DType>(),
out_grad.dptr<DType>(),
mean_data.dptr<DType>(),
std_data.dptr<DType>(),
gamma.dptr<DType>(),
data_grad.dptr<DType>());
}
});
MSHADOW_CUDA_POST_KERNEL_CHECK(LayerNormFusedBackwardKernel_Data);
}
}
template <>
void LayerNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
int axis = param.axis;
if (axis < 0) {
axis += static_cast<int>(inputs[0].ndim());
}
CHECK(axis >= 0 && axis < inputs[0].ndim()) << "Channel axis out of range: " << param.axis;
if (axis == inputs[0].ndim() - 1) {
// Use the accelerated CUDA kernels
bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true);
if (safe_acc) {
return LayerNormGradGPUContig<true>(param, ctx, inputs, req, outputs);
} else {
return LayerNormGradGPUContig<false>(param, ctx, inputs, req, outputs);
}
}
return LayerNormGradComputeGeneral<gpu>(attrs, ctx, inputs, req, outputs);
}
NNVM_REGISTER_OP(LayerNorm).set_attr<FCompute>("FCompute<gpu>", LayerNormCompute<gpu>);
NNVM_REGISTER_OP(_backward_LayerNorm)
.set_attr<FCompute>("FCompute<gpu>", LayerNormGradCompute<gpu>);
} // namespace op
} // namespace mxnet