// blob: 0ef5733f9f8cf385b1980ad14610da5051630248 [file] [log] [blame]
/*!
* Copyright (c) 2015 by Contributors
* \file batch_norm.cc
* \brief
* \author Bing Xu, Chris Olivier
*/
#include "batch_norm-inl.h"
#include <nnvm/op_attr_types.h>
#if MXNET_USE_MKL2017 == 1
#include <mkl_memory.h>
#include "./mkl/mkl_memory-inl.h"
#include "./mkl/mkl_batch_norm-inl.h"
#endif // MXNET_USE_MKL2017
/*!
 * \brief Conversions between variance and inverse standard deviation.
 *
 * VARIANCE_TO_INVSTD(var, eps) evaluates to 1 / sqrt(var + eps). It casts eps
 * with DType(...), so a type named DType must be in scope wherever the macro is
 * expanded.
 *
 * INVSTD_TO_VARIANCE(invstd, eps) is the inverse mapping, 1 / invstd^2 - eps.
 * It divides by invstd * invstd, so passing an invstd of zero divides by zero.
 */
#define VARIANCE_TO_INVSTD(__var$, __eps$) (1.0/sqrt((__var$) + DType(__eps$)))
#define INVSTD_TO_VARIANCE(__invstd$, __eps$) ((1.0 / ((__invstd$) * (__invstd$))) - (__eps$))
namespace mxnet {
namespace op {
namespace batchnorm {
/*!
 * \brief Global disable of batchnorm mkl operator for unit testing.
 *
 * In MXNET_USE_MKL2017 builds, CreateOp<cpu> reads this flag: when it is true
 * the MKL operator is never selected and the "Skipping MKL optimization" log
 * line is suppressed.
 * NOTE(review): volatile provides no inter-thread synchronization; presumably
 * the flag is only toggled by test code before operators are created - confirm.
 */
volatile bool disable_mkl = false;
/*!
 * \brief Visit every element that belongs to one channel of a tensor.
 *
 * The traversal is position-agnostic: the callback only receives a pointer to
 * the current element. For each of the OuterSize() outer slices it walks
 * InnerSize() consecutive elements and then jumps ahead by
 * SkipLengthToNextSameChannelData() to reach the same channel in the next slice.
 *
 * \param tensor  tensor view to iterate over
 * \param channel channel whose elements are visited
 * \param onData  callable invoked as onData(DType *element) for every element
 */
template<typename DType, typename OnData>
static inline void ForEachFast(const BNTensor3<DType> &tensor,
                               const size_t channel,
                               OnData onData) {
  const size_t outerCount = tensor.OuterSize();
  const size_t innerCount = tensor.InnerSize();
  const size_t gapToNextSlice = tensor.SkipLengthToNextSameChannelData();
  const size_t firstElement = tensor.StartOffset(channel);

  DType *cursor = tensor.dptr_ + firstElement;
  for (size_t slicesLeft = outerCount; slicesLeft != 0; --slicesLeft) {
    // One contiguous run of this channel's data inside the current slice
    for (size_t elementsLeft = innerCount; elementsLeft != 0; --elementsLeft, ++cursor) {
      onData(cursor);
    }
    // Hop over the other channels to the same channel in the next slice
    cursor += gapToNextSlice;
  }
}
/*! \brief Fast-foreach when you don't care about the position other than channel */
template<typename DType1, typename DType2, typename OnData>
static inline void ForEachFast(const BNTensor3<DType1> &in_data,
const BNTensor3<DType2> &out_data,
const size_t channel,
OnData onData) {
const size_t num = in_data.OuterSize();
const size_t matrixSize = in_data.InnerSize();
const size_t skipLength = in_data.SkipLengthToNextSameChannelData();
const size_t startOffset = in_data.StartOffset(channel);
DType1 *data = in_data.dptr_ + startOffset;
DType2 *odata = out_data.dptr_ + startOffset;
for (size_t outer = 0; outer < num; ++outer) {
for (size_t i = 0; i < matrixSize; ++i) {
onData(data++, odata++);
}
data += skipLength;
odata += skipLength;
}
}
} // namespace batchnorm
/*!
 * \brief Forward CPU
 *
 * For every channel this either computes the batch mean and inverse standard
 * deviation from the input (training without global stats) or derives them from
 * the moving statistics in \p aux_states. The pair is stored in
 * out_data[kMean] / out_data[kVar] and then used to write the normalized,
 * scaled and shifted data to out_data[kOut].
 *
 * \note out_data[kVar] receives the inverse standard deviation
 *       1 / sqrt(variance + eps), not the variance itself.
 *
 * \param ctx        operator context; only ctx.is_train is read
 * \param in_data    inputs, indexed by kData, kGamma and kBeta
 * \param req        write requests, consulted through IsWriting()
 * \param out_data   outputs, indexed by kOut, kMean and kVar
 * \param aux_states moving statistics, indexed by kMovingMean and kMovingVar
 */
template <typename xpu, typename DType, typename AccReal>
void BatchNormOp<xpu, DType, AccReal>::DoForward(mshadow::Stream<cpu> *,
                                                 const OpContext &ctx,
                                                 const std::vector<TBlob> &in_data,
                                                 const std::vector<OpReqType> &req,
                                                 const std::vector<TBlob> &out_data,
                                                 const std::vector<TBlob> &aux_states) {
  // Input
  batchnorm::BNTensor3<DType> inputData(in_data[batchnorm::kData], param_.axis);
  const TBlob &weights = in_data[batchnorm::kGamma];
  const TBlob &bias = in_data[batchnorm::kBeta];

  // Aux (Moving)
  const TBlob &runningMean = aux_states[batchnorm::kMovingMean];
  const TBlob &runningVariance = aux_states[batchnorm::kMovingVar];

  // Output
  batchnorm::BNTensor3<DType> outputData(out_data[batchnorm::kOut], param_.axis);
  const TBlob &meanVector = out_data[batchnorm::kMean];
  const TBlob &varianceVector = out_data[batchnorm::kVar];

  AccReal *mean = meanVector.dptr<AccReal>();
  AccReal *var = varianceVector.dptr<AccReal>();

  const bool is_train_and_not_global_stats = ctx.is_train && !param_.use_global_stats;
  const size_t channelCount = inputData.ChannelCount();
  const size_t itemCountPerChannel = inputData.Size() / channelCount;

  // Avoid multiple dptr() call within the channel loop
  AccReal *w = weights.dptr<AccReal>();
  const AccReal *b = bias.dptr<AccReal>();
  // The moving statistics are only touched when they are actually consumed
  const AccReal *rm = nullptr;
  const AccReal *rv = nullptr;
  if (!is_train_and_not_global_stats) {
    rm = runningMean.dptr<AccReal>();
    rv = runningVariance.dptr<AccReal>();
  }

  // Signed bound so the OpenMP loop does not mix signed and unsigned operands
  const int channelTotal = static_cast<int>(channelCount);

  #pragma omp parallel for
  for (int channel = 0; channel < channelTotal; ++channel) {
    const size_t channelIndex = static_cast<size_t>(channel);

    if (is_train_and_not_global_stats) {
      // compute mean per input; accumulate in a thread-local value instead of
      // repeatedly writing through the shared output array
      AccReal batchMean = 0;
      ForEachFast(inputData, channelIndex,
                  [&batchMean](const DType *value) {
                    batchMean += *value;
                  });
      batchMean /= itemCountPerChannel;
      mean[channel] = batchMean;

      // compute variance per input (sum of squared deviations first)
      AccReal sum = 0;
      ForEachFast(inputData, channelIndex,
                  [&sum, batchMean](const DType *value) {
                    const AccReal current = *value;
                    sum += (current - batchMean) * (current - batchMean);
                  });

      AccReal invstd;
      if (sum == 0 && param_.eps == 0.0) {
        // Nobody likes to divide by zero
        invstd = 0;
      } else {
        const AccReal variance = sum / itemCountPerChannel;
        invstd = VARIANCE_TO_INVSTD(variance, param_.eps);
      }
      var[channel] = invstd;
    } else {
      mean[channel] = rm[channel];
      var[channel] = VARIANCE_TO_INVSTD(rv[channel], param_.eps);
    }

    // compute output
    const AccReal thisMean = mean[channel];
    const AccReal thisInvstd = var[channel];  // note that var is still invstd
    const AccReal thisWeight = w[channel];
    const AccReal thisBias = b[channel];

    // NOTE(review): req is indexed with the input-side enumerators (kData,
    // kGamma) below, exactly as before - confirm they line up with the slots
    // that are meant to be checked.
    if (!param_.fix_gamma) {
      if (IsWriting(req[batchnorm::kData])) {
        ForEachFast(inputData, outputData, channelIndex,
                    [thisWeight, thisBias, thisMean, thisInvstd](const DType *src,
                                                                 DType *dst) {
                      *dst = static_cast<DType>(
                        ((*src - thisMean) * thisInvstd) * thisWeight + thisBias);
                    });
      }
    } else {
      // gamma is pinned to 1: reset the stored weight and leave it out of the output
      if (IsWriting(req[batchnorm::kGamma])) {
        w[channel] = AccReal(1);
      }
      if (IsWriting(req[batchnorm::kData])) {
        ForEachFast(inputData, outputData, channelIndex,
                    [thisBias, thisMean, thisInvstd](const DType *src,
                                                     DType *dst) {
                      *dst = static_cast<DType>(
                        ((*src - thisMean) * thisInvstd) + thisBias);
                    });
      }
    }
  }
}
/*!
 * \brief Backward CPU
 *
 * Per channel this computes the gradient with respect to the data, gamma and
 * beta from the output gradient. When training without global stats it uses the
 * mean / inverse standard deviation saved in \p out_data by the forward pass and
 * also updates the moving mean and variance in \p aux_states; otherwise it uses
 * the moving statistics and leaves them untouched.
 *
 * \param ctx        operator context; only ctx.is_train is read
 * \param out_grad   gradient of the output, indexed by kOut
 * \param in_data    forward inputs, indexed by kData and kGamma
 * \param out_data   forward outputs; kMean and kVar hold mean and invstd
 * \param req        write requests, consulted through IsWriting()
 * \param in_grad    gradients to produce, indexed by kData, kGamma and kBeta
 * \param aux_states moving statistics, indexed by kMovingMean and kMovingVar
 */
template <typename xpu, typename DType, typename AccReal>
void BatchNormOp<xpu, DType, AccReal>::DoBackward(mshadow::Stream<cpu> *,
                                                  const OpContext &ctx,
                                                  const std::vector<TBlob> &out_grad,
                                                  const std::vector<TBlob> &in_data,
                                                  const std::vector<TBlob> &out_data,
                                                  const std::vector<OpReqType> &req,
                                                  const std::vector<TBlob> &in_grad,
                                                  const std::vector<TBlob> &aux_states) {
  // Input Data
  batchnorm::BNTensor3<DType> inputData(in_data[batchnorm::kData], param_.axis);
  const TBlob &weights = in_data[batchnorm::kGamma];

  // Input Grad
  batchnorm::BNTensor3<DType> gradIn(in_grad[batchnorm::kData], param_.axis);
  const TBlob &gradWeight = in_grad[batchnorm::kGamma];
  const TBlob &gradBias = in_grad[batchnorm::kBeta];

  // Aux (Moving)
  const TBlob &runningMean = aux_states[batchnorm::kMovingMean];
  const TBlob &runningVariance = aux_states[batchnorm::kMovingVar];

  // Output
  batchnorm::BNTensor3<DType> gradOut(out_grad[batchnorm::kOut], param_.axis);
  const TBlob &saveMean = out_data[batchnorm::kMean];
  const TBlob &saveStd = out_data[batchnorm::kVar];

  const size_t channelCount = inputData.ChannelCount();
  const size_t itemCount = inputData.Size() / channelCount;

  // Avoid multiple dptr() call within the channel loop
  AccReal *runningMeanDataPtr = runningMean.dptr<AccReal>();
  AccReal *runningVarDataPtr = runningVariance.dptr<AccReal>();
  const AccReal *saveMeanDataPtr = saveMean.dptr<AccReal>();
  const AccReal *saveInvStdDataPtr = saveStd.dptr<AccReal>();
  AccReal *gradWeightData = gradWeight.dptr<AccReal>();
  AccReal *gradBiasData = gradBias.dptr<AccReal>();
  const AccReal *weight = weights.dptr<AccReal>();

  const bool is_train_and_not_global_stats = ctx.is_train && !param_.use_global_stats;
  const int channelTotal = static_cast<int>(channelCount);

  #pragma omp parallel for
  for (int channel = 0; channel < channelTotal; ++channel) {
    const size_t channelIndex = static_cast<size_t>(channel);
    // A missing gamma blob is treated as a weight of one
    const AccReal w = weight ? weight[channel] : AccReal(1);
    AccReal mean, invstd;
    if (is_train_and_not_global_stats) {
      mean = saveMeanDataPtr[channel];
      invstd = saveInvStdDataPtr[channel];
      AccReal variance;
      if (invstd == 0 && param_.eps == 0.0) {
        // The forward pass stores invstd == 0 for a zero-variance channel when
        // eps is zero; inverting that would divide by zero and push inf/NaN
        // into the moving variance, so use the variance it stands for.
        variance = 0;
      } else {
        variance = INVSTD_TO_VARIANCE(invstd, param_.eps);
      }

      // update running averages
      runningMeanDataPtr[channel] = runningMeanDataPtr[channel] * param_.momentum
                                    + mean * (AccReal(1) - param_.momentum);

      runningVarDataPtr[channel] = runningVarDataPtr[channel] * param_.momentum
                                   + variance * (AccReal(1) - param_.momentum);
    } else {
      mean = runningMeanDataPtr[channel];
      invstd = VARIANCE_TO_INVSTD(runningVarDataPtr[channel], param_.eps);
    }

    // sumGradOut over all gradOutput in feature plane
    AccReal sumGradOut = 0;
    ForEachFast(gradOut, channelIndex,
                [&sumGradOut](const DType *gradOut_data) {
                  sumGradOut += *gradOut_data;
                });

    // dot product of the Q(X) and gradOuput
    AccReal dotp = 0;
    ForEachFast(inputData, gradOut, channelIndex,
                [&dotp, mean](const DType *thisInputData, const DType *gradOut_data) {
                  dotp += (*thisInputData - mean) * (*gradOut_data);
                });

    if (!gradIn.IsEmpty() && IsWriting(req[batchnorm::kData])) {  // if there's a grad input
      if (is_train_and_not_global_stats) {
        // when in training mode
        // Q(X) = X - E[x] ; i.e. input centered to zero mean
        // Y = Q(X) / σ ; i.e. BN output before weight and bias
        // dL/dX = (Q(dL/dY) - dot(Y, dL/dY) * Y) / σ * w

        // projection of gradOutput on to output scaled by std
        const AccReal k = dotp * invstd * invstd / itemCount;
        // First pass: park the projection term in gradIn
        ForEachFast(inputData, gradIn, channelIndex,
                    [mean, k](const DType *inputDataPtr, DType *gradIn_data) {
                      *gradIn_data = (*inputDataPtr - mean) * k;
                    });

        const AccReal iw = invstd * w;
        const AccReal gradMean = sumGradOut / itemCount;
        // Second pass: combine it with the centered output gradient
        ForEachFast(gradOut, gradIn, channelIndex,
                    [iw, gradMean](const DType *gradOut_data, DType *gradIn_data) {
                      *gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * iw;
                    });
      } else {
        // when in evaluation mode
        // Q(X) = X - running_mean ; i.e. input centered to zero mean
        // Y = Q(X) / running_std ; i.e. BN output before weight and bias
        // dL/dX = w / running_std
        const AccReal iw = invstd * w;
        ForEachFast(gradOut, gradIn, channelIndex,
                    [iw](const DType *gradOut_data, DType *gradIn_data) {
                      *gradIn_data = *gradOut_data * iw;
                    });
      }
    }

    // May want to make this a param eventually
    const AccReal scale = 1.0f;

    if (IsWriting(req[batchnorm::kGamma])) {
      if (!param_.fix_gamma) {
        gradWeightData[channel] = scale * dotp * invstd;
      } else {
        // gamma is fixed, so it receives no gradient
        gradWeightData[channel] = AccReal(0);
      }
    }

    if (IsWriting(req[batchnorm::kBeta])) {
      gradBiasData[channel] = scale * sumGradOut;
    }
  }
}
/*!
 * \brief Create the CPU BatchNorm operator for a data type and input shape.
 *
 * The axis in \p param is first resolved through batchnorm::GetRealAxis. In
 * MXNET_USE_MKL2017 builds an MKL operator is preferred for 4-dimensional input
 * on the default axis with float32/float64 data, unless batchnorm::disable_mkl
 * is set. In every other case the generic BatchNormOp is instantiated through
 * MSHADOW_REAL_TYPE_SWITCH_EX.
 *
 * \param param operator parameters (taken by value; axis is overwritten)
 * \param dtype mshadow type flag of the data
 * \param shape shape of the data input
 * \return newly allocated operator
 */
template<>
Operator *CreateOp<cpu>(BatchNormParam param, const int dtype, const TShape& shape) {
  param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
  Operator *created = nullptr;
#if MXNET_USE_MKL2017 == 1
#define BATCHNORM_LOG_MKL_INFO() \
  do { \
    if (!mxnet::op::batchnorm::disable_mkl) { \
      LOG(INFO) << MKLBatchNormOp<cpu, float>::getName() \
        << " Skipping MKL optimization (unsupported dimension, axis or type)"; \
    } \
  } while (0)
  if (shape.ndim() == 4
      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS
      && !mxnet::op::batchnorm::disable_mkl) {
    if (dtype == mshadow::kFloat32) {
      created = new MKLBatchNormOp<cpu, float>(param);
    } else if (dtype == mshadow::kFloat64) {
      created = new MKLBatchNormOp<cpu, double>(param);
    }
    // MKL operator doesn't support half_t, so any other type is left to the
    // generic implementation below
  }
#else
#define BATCHNORM_LOG_MKL_INFO() ((void)0)
#endif
  if (created != nullptr) {
    return created;
  }
  // No MKL operator was selected: build the generic one for this dtype
  MSHADOW_REAL_TYPE_SWITCH_EX(dtype,
                              DType,
                              AccReal, {
                                BATCHNORM_LOG_MKL_INFO();
                                created = new BatchNormOp<cpu, DType, AccReal>(param); });
  return created;
}
/*!
 * \brief Create the operator once type and shape inference have succeeded.
 *
 * Runs InferType and InferShape on the supplied inputs (each guarded by CHECK),
 * verifies that at least one input shape is present, and then dispatches to
 * CreateOp with the type and shape of the first input.
 *
 * DO_BIND_DISPATCH comes from operator_common.h; no explicit return follows it,
 * so the macro is expected to produce the function's return value.
 *
 * \param ctx      context the operator is created for
 * \param in_shape input shapes, passed to InferShape
 * \param in_type  input type flags, passed to InferType
 */
Operator *BatchNormProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
                                          std::vector<int> *in_type) const {
  // Type inference first; the inferred output/aux types are not used further here
  std::vector<int> out_type;
  std::vector<int> aux_type;
  CHECK(InferType(in_type, &out_type, &aux_type));

  // Then shape inference; the results are likewise only needed for validation
  std::vector<TShape> out_shape;
  std::vector<TShape> aux_shape;
  CHECK(InferShape(in_shape, &out_shape, &aux_shape));

  CHECK_GE(in_shape->size(), 1U);
  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_shape)[0]);
}
// Register the BatchNormParam parameter structure.
DMLC_REGISTER_PARAMETER(BatchNormParam);
// Register the BatchNorm operator property: the user-facing description, the
// five named inputs in order (data, gamma, beta, moving_mean, moving_var) and
// the fields of BatchNormParam as additional arguments.
MXNET_REGISTER_OP_PROPERTY(BatchNorm, BatchNormProp)
.describe(R"code(Batch normalization.
Normalizes a data batch by mean and variance, and applies a scale ``gamma`` as
well as offset ``beta``.
Assume the input has more than one dimension and we normalize along axis 1.
We first compute the mean and variance along this axis:
.. math::
data\_mean[i] = mean(data[:,i,:,...]) \\
data\_var[i] = var(data[:,i,:,...])
Then compute the normalized output, which has the same shape as input, as following:
.. math::
out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}} * gamma[i] + beta[i]
Both *mean* and *var* returns a scalar by treating the input as a vector.
Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
have shape *(k,)*. If ``output_mean_var`` is set to be true, then outputs both ``data_mean`` and
``data_var`` as well, which are needed for the backward pass.
Besides the inputs and the outputs, this operator accepts two auxiliary
states, ``moving_mean`` and ``moving_var``, which are *k*-length
vectors. They are global statistics for the whole dataset, which are updated
by::
moving_mean = moving_mean * momentum + data_mean * (1 - momentum)
moving_var = moving_var * momentum + data_var * (1 - momentum)
If ``use_global_stats`` is set to be true, then ``moving_mean`` and
``moving_var`` are used instead of ``data_mean`` and ``data_var`` to compute
the output. It is often used during inference.
The parameter ``axis`` specifies which axis of the input shape denotes
the 'channel' (separately normalized groups). The default is 1. Specifying -1 sets the channel
axis to be the last item in the input shape.
Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,
then set ``gamma`` to 1 and its gradient to 0.
)code" ADD_FILELINE)
.add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization")
.add_argument("gamma", "NDArray-or-Symbol", "gamma array")
.add_argument("beta", "NDArray-or-Symbol", "beta array")
.add_argument("moving_mean", "NDArray-or-Symbol", "running mean of input")
.add_argument("moving_var", "NDArray-or-Symbol", "running variance of input")
.add_arguments(BatchNormParam::__FIELDS__());
// Attach default initializers to the input variables at positions 3 and 4
// (moving_mean and moving_var in the argument order registered for BatchNorm)
// when a symbol is composed. A variable that already carries an "__init__"
// attribute is left alone.
NNVM_REGISTER_OP(BatchNorm)
.set_attr<nnvm::FSetInputVarAttrOnCompose>(
  "FSetInputVarAttrOnCompose",
  [](const nnvm::NodeAttrs& attrs, nnvm::NodePtr var, const int index) {
    auto &varAttrs = var->attrs.dict;
    // Respect an initializer the user has set explicitly
    if (varAttrs.find("__init__") != varAttrs.end()) {
      return;
    }
    switch (index) {
      case 3:
        varAttrs["__init__"] = "[\"zero\", {}]";
        break;
      case 4:
        varAttrs["__init__"] = "[\"one\", {}]";
        break;
      default:
        break;
    }
  });
} // namespace op
} // namespace mxnet