/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file layer_norm-inl.h
 * \brief Implements Ba et al., "Layer Normalization" (https://arxiv.org/abs/1607.06450).
*/
#ifndef MXNET_OPERATOR_NN_LAYER_NORM_INL_H_
#define MXNET_OPERATOR_NN_LAYER_NORM_INL_H_
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <mshadow/base.h>
#include <map>
#include <algorithm>
#include <vector>
#include <string>
#include <utility>
#include "../mshadow_op.h"
#include "../operator_common.h"
#include "../mxnet_op.h"
#include "../tensor/broadcast_reduce_op.h"
namespace mxnet {
namespace op {
namespace layernorm {
enum LayerNormOpInputs {kData, kGamma, kBeta}; // kGamma: scaling parameters, kBeta: shift biases
enum LayerNormOpOutputs {kOut, kMean, kStd};  // kOut: normalized output, kMean: mean, kStd: standard deviation
} // namespace layernorm
struct LayerNormParam : public dmlc::Parameter<LayerNormParam> {
int axis;
float eps;
bool output_mean_var;
DMLC_DECLARE_PARAMETER(LayerNormParam) {
DMLC_DECLARE_FIELD(axis).set_default(-1)
.describe("The axis to perform layer normalization. "
"Usually, this should be be axis of the channel dimension. "
"Negative values means indexing from right to left.");
DMLC_DECLARE_FIELD(eps).set_default(1e-5f)
.describe("An `epsilon` parameter to prevent division by 0.");
DMLC_DECLARE_FIELD(output_mean_var).set_default(false)
.describe("Output the mean and std calculated along the given axis.");
}
};
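
/*
 * Forward pass computed below, along the chosen `axis`:
 *
 *   out = gamma * (data - mean(data, axis)) / sqrt(var(data, axis) + eps) + beta
 *
 * The per-axis mean and std are also written to kMean/kStd so the backward
 * pass can reuse them (and so they can be returned when output_mean_var=true).
 */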
template<typename xpu>
void LayerNormCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx, const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
if (req[0] == kNullOp) return;
CHECK_NE(req[0], kAddTo);
int axis = param.axis;
if (axis < 0) {
axis += static_cast<int>(inputs[0].ndim());
}
CHECK(axis >= 0 && axis < inputs[0].ndim()) << "Channel axis out of range: " << param.axis;
CHECK_EQ(inputs.size(), 3U);
Stream<xpu> *s = ctx.get_stream<xpu>();
// Reshape gamma and beta to be broadcastable
mxnet::TShape new_param_shape(inputs[0].shape_.begin(), inputs[0].shape_.end());
for (int i = 0; i < inputs[0].ndim(); i++) {
if (i != axis) {
new_param_shape[i] = 1;
}
}
const TBlob gamma = inputs[1].reshape(new_param_shape);
const TBlob beta = inputs[2].reshape(new_param_shape);
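  // gamma and beta come in with shape (C,); reshaping them to the broadcastable
  // shape (1, ..., C, ..., 1) lets the BinaryBroadcastCompute calls below apply
  // them along `axis`.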
// Compute necessary data for the reduce operation.
mxnet::TShape red_src_shape, red_dst_shape;
BroadcastReduceShapeCompact(inputs[0].shape_, outputs[layernorm::kMean].shape_,
&red_src_shape, &red_dst_shape);
const TBlob in_data = inputs[0].reshape(red_src_shape);
const TBlob mean_data = outputs[layernorm::kMean].reshape(red_dst_shape);
const TBlob std_data = outputs[layernorm::kStd].reshape(red_dst_shape);
int channel_size = red_src_shape.Size() / red_dst_shape.Size();
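  // channel_size is the number of elements reduced into each mean/std entry,
  // i.e. the length of the normalized axis.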
// Initialize the workspace
Tensor<xpu, 1, char> workspace;
size_t workspace_size = 0;
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
workspace_size =
broadcast::ReduceWorkspaceSize<NDim, DType>(s, mean_data.shape_, req[0], in_data.shape_);
});
});
workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
// Calculate mean
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity>(
s, mean_data, req[0], workspace, in_data);
Tensor<xpu, 1, DType> mean_data_tensor = mean_data.FlatTo1D<xpu, DType>(s);
mean_data_tensor /= scalar<DType>(channel_size);
});
});
// Calculate data = data - mean
BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
{inputs[0], outputs[layernorm::kMean]},
{kWriteTo}, {outputs[0]});
// Calculate std
const TBlob centered_out = outputs[0].reshape(red_src_shape);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::square>(
s, std_data, req[0], workspace, centered_out);
Tensor<xpu, 1, DType> std_data_tensor = std_data.FlatTo1D<xpu, DType>(s);
std_data_tensor = F<mshadow_op::square_root>(std_data_tensor / scalar<DType>(channel_size)
+ scalar<DType>(param.eps));
});
});
// Calculate data = data / std
BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
{outputs[0], outputs[layernorm::kStd]},
{kWriteTo}, {outputs[0]});
// Calculate data = data * gamma
BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
{outputs[0], gamma},
{kWriteTo}, {outputs[0]});
// Calculate data = data + beta
BinaryBroadcastCompute<xpu, op::mshadow_op::plus>(attrs, ctx,
{outputs[0], beta},
{kWriteTo}, {outputs[0]});
}
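/*
 * Note: the forward pass above is a sequence of generic broadcast/reduce
 * kernels (mean, centering, std, scale, shift) rather than one fused kernel;
 * outputs[0] doubles as the intermediate buffer between the steps, which is
 * why req[0] == kAddTo is rejected up front.
 */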
/*
Calculate the gradient of layer normalization.
Let og denote the output gradient and define the normalized input and the
scaled gradient as:
  \bar{x} = (x - mean) / std
  w = og * gamma / std
Then the gradients for gamma, beta and x are:
  grad_gamma = sum(og * \bar{x}, exclude_axis)
  grad_beta  = sum(og, exclude_axis)
  grad_x     = w - mean(w, axis) - \bar{x} * mean(w * \bar{x}, axis)
*/
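/*
A sketch of where grad_x comes from, per normalized slice of length
m = channel_size: with \bar{x}_i = (x_i - mean) / std and std = sqrt(var + eps),

  d\bar{x}_i / dx_j = (delta_ij - 1/m) / std - \bar{x}_i * \bar{x}_j / (m * std)

Contracting with og_i * gamma = w_i * std and summing over i yields

  grad_x_j = w_j - mean(w, axis) - \bar{x}_j * mean(w * \bar{x}, axis)

which is what the code below accumulates in three steps.
*/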
template<typename xpu>
void LayerNormGradCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx, const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(inputs.size(), 5U);
const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
int axis = param.axis;
if (axis < 0) {
axis += static_cast<int>(inputs[0].ndim());
}
CHECK(axis >= 0 && axis < inputs[0].ndim()) << "Channel axis out of range: " << param.axis;
Stream<xpu> *s = ctx.get_stream<xpu>();
// Reshape gamma to be broadcastable
mxnet::TShape new_param_shape(inputs[0].shape_.begin(), inputs[0].shape_.end());
for (int i = 0; i < inputs[0].ndim(); i++) {
if (i != axis) {
new_param_shape[i] = 1;
}
}
const TBlob ograd = inputs[0];
const TBlob data = inputs[1];
const TBlob gamma = inputs[2].reshape(new_param_shape);
const TBlob mean = inputs[3];
const TBlob std = inputs[4];
// Prepare the necessary shapes for reduction
mxnet::TShape red_src_shape, red_dst_shape, red_exclude_src_shape, red_exclude_dst_shape;
BroadcastReduceShapeCompact(ograd.shape_, mean.shape_, &red_src_shape, &red_dst_shape);
BroadcastReduceShapeCompact(ograd.shape_, gamma.shape_,
&red_exclude_src_shape, &red_exclude_dst_shape);
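  // red_src/red_dst describe the reduction over `axis` (used for the mean-style
  // reductions in grad_x), while red_exclude_src/red_exclude_dst describe the
  // reduction over every axis except `axis` (used for grad_gamma and grad_beta).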
int channel_size = red_src_shape.Size() / red_dst_shape.Size();
  // Initialize the workspace and construct the temporary TBlobs
Tensor<xpu, 1, char> workspace;
size_t reduce_workspace_size = 0;
size_t data_size = 0;
size_t red_out_size = 0;
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
data_size = sizeof(DType) * data.Size();
red_out_size = sizeof(DType) * mean.Size();
    // There are two types of reduction workloads: reduce over axis and reduce exclude axis.
    // We take the maximum of the workspace sizes required by these workloads.
    // Also, we explicitly pass req_type=kAddTo so the buffer is large enough for either request.
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
reduce_workspace_size =
std::max(reduce_workspace_size,
broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_src_shape,
kAddTo, red_dst_shape));
});
BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
reduce_workspace_size =
std::max(reduce_workspace_size,
broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_exclude_src_shape, kAddTo,
red_exclude_dst_shape));
});
});
workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(
Shape1(reduce_workspace_size + data_size * 2 + red_out_size), s);
const TBlob normalized_data = TBlob(workspace.dptr_ + reduce_workspace_size,
data.shape_, data.dev_mask(), data.type_flag_, data.dev_id());
const TBlob ograd_mult = TBlob(workspace.dptr_ + reduce_workspace_size + data_size,
ograd.shape_, ograd.dev_mask(), ograd.type_flag_, ograd.dev_id());
const TBlob red_out = TBlob(workspace.dptr_ + reduce_workspace_size + data_size * 2,
mean.shape_, mean.dev_mask(), mean.type_flag_, mean.dev_id());
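  // The single allocation is carved up as:
  //   [reduce scratch | normalized_data | ograd_mult | red_out]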
// Compute normalized_data = (data - mean) / std
BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
{data, mean},
{kWriteTo}, {normalized_data});
BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
{normalized_data, std},
{kWriteTo}, {normalized_data});
// Calculate grad_beta
if (req[2] != kNullOp) {
MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity>(
s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
ograd.reshape(red_exclude_src_shape));
});
});
}
  // Calculate grad_gamma: sum(ograd * normalized_data, exclude_axis)
ElemwiseBinaryOp::Compute<xpu, op::mshadow_op::mul>(attrs, ctx, {normalized_data, ograd},
{kWriteTo}, {ograd_mult});
if (req[1] != kNullOp) {
MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity>(
s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
ograd_mult.reshape(red_exclude_src_shape));
});
});
}
// Calculate grad_data:
// ograd_mult = ograd * gamma / std
// grad_data = ograd_mult - mean(ograd_mult, axis)
// + normalized_data * (-mean(normalized_data * ograd_mult, axis))
if (req[0] != kNullOp) {
BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
{ograd, gamma},
{kWriteTo}, {ograd_mult});
BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
{ograd_mult, std},
{kWriteTo}, {ograd_mult});
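    // ograd_mult now holds w = ograd * gamma / std (see the formula above).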
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity>(
s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
ograd_mult.reshape(red_src_shape));
});
Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
red_out_tensor /= scalar<DType>(channel_size);
});
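    // red_out now holds mean(w, axis); subtracting it from w gives the first
    // two terms of grad_x.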
BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
{ograd_mult, red_out},
{req[0]}, {outputs[0]});
ElemwiseBinaryOp::Compute<xpu, op::mshadow_op::mul>(attrs, ctx, {ograd_mult, normalized_data},
{kWriteTo}, {ograd_mult});
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity>(
s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
ograd_mult.reshape(red_src_shape));
});
Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
      red_out_tensor /= scalar<DType>(-channel_size);
});
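    // red_out now holds -mean(w * \bar{x}, axis); the kAddTo broadcast below
    // adds normalized_data * red_out, completing grad_x.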
BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
{normalized_data, red_out},
{kAddTo}, {outputs[0]});
}
}
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_NN_LAYER_NORM_INL_H_