blob: 59cbfa05a075ece3faa1f985e4e7576b0ba2c5eb [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file sync_batch_norm.cc
* \brief Synchronized BatchNorm modified from BatchNormV1
* \author Hang Zhang
*/
#include "sync_batch_norm-inl.h"
#include <nnvm/op_attr_types.h>
namespace mxnet {
namespace op {
/*!
 * \brief CPU specialization of the operator factory.
 *
 * Allocates a SyncBatchNorm<cpu> configured with \p param and hands the raw
 * pointer back to the caller.
 *
 * \param param operator parameters forwarded to the SyncBatchNorm constructor
 * \param dtype accepted for signature compatibility; not read in this body
 * \return pointer to the newly allocated operator
 */
template <>
Operator* CreateOp<cpu>(SyncBatchNormParam param, int dtype) {
  Operator* cpu_op = new SyncBatchNorm<cpu>(param);
  return cpu_op;
}
/*!
 * \brief Build the operator for the given context after validating the inputs.
 *
 * Type inference runs first, then shape inference; each call is wrapped in
 * CHECK. The operator is then produced through DO_BIND_DISPATCH (defined in
 * operator_common.h), which is handed CreateOp, the stored parameters and the
 * type of the first input.
 *
 * \param ctx context; not named in this body (presumably consumed by the
 *            dispatch macro — confirm in operator_common.h)
 * \param in_shape input shapes, passed to InferShape
 * \param in_type input types, passed to InferType; element 0 selects the dtype
 * \return the created operator (this body has no explicit return statement, so
 *         the dispatch macro is expected to return — confirm in operator_common.h)
 */
Operator* SyncBatchNormProp::CreateOperatorEx(Context ctx,
                                              mxnet::ShapeVector* in_shape,
                                              std::vector<int>* in_type) const {
  // Type inference: outputs and auxiliary states are scratch results here.
  std::vector<int> out_type;
  std::vector<int> aux_type;
  CHECK(InferType(in_type, &out_type, &aux_type));

  // Shape inference: likewise only used for validation in this function.
  mxnet::ShapeVector out_shape;
  mxnet::ShapeVector aux_shape;
  CHECK(InferShape(in_shape, &out_shape, &aux_shape));

  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
}
// Registers the SyncBatchNormParam parameter struct through dmlc's
// DMLC_REGISTER_PARAMETER macro.
DMLC_REGISTER_PARAMETER(SyncBatchNormParam);
// Registers the operator property `_contrib_SyncBatchNorm`, backed by
// SyncBatchNormProp. The raw string given to describe() is the operator's
// description text (with ADD_FILELINE appended). The add_argument() calls
// then declare five inputs in this order: data, gamma, beta, moving_mean,
// moving_var; add_arguments() appends the fields of SyncBatchNormParam.
// NOTE(review): the description text contains typos ("sync-onece
// implmentation", "only normalize"). It is a runtime string, so it is left
// untouched here — fix in a dedicated change if desired.
MXNET_REGISTER_OP_PROPERTY(_contrib_SyncBatchNorm, SyncBatchNormProp)
.describe(R"code(Batch normalization.
Normalizes a data batch by mean and variance, and applies a scale ``gamma`` as
well as offset ``beta``.
Standard BN [1]_ implementation only normalize the data within each device.
SyncBN normalizes the input within the whole mini-batch.
We follow the sync-onece implmentation described in the paper [2]_.
Assume the input has more than one dimension and we normalize along axis 1.
We first compute the mean and variance along this axis:
.. math::
data\_mean[i] = mean(data[:,i,:,...]) \\
data\_var[i] = var(data[:,i,:,...])
Then compute the normalized output, which has the same shape as input, as following:
.. math::
out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}} * gamma[i] + beta[i]
Both *mean* and *var* returns a scalar by treating the input as a vector.
Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
have shape *(k,)*. If ``output_mean_var`` is set to be true, then outputs both ``data_mean`` and
``data_var`` as well, which are needed for the backward pass.
Besides the inputs and the outputs, this operator accepts two auxiliary
states, ``moving_mean`` and ``moving_var``, which are *k*-length
vectors. They are global statistics for the whole dataset, which are updated
by::
moving_mean = moving_mean * momentum + data_mean * (1 - momentum)
moving_var = moving_var * momentum + data_var * (1 - momentum)
If ``use_global_stats`` is set to be true, then ``moving_mean`` and
``moving_var`` are used instead of ``data_mean`` and ``data_var`` to compute
the output. It is often used during inference.
Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,
then set ``gamma`` to 1 and its gradient to 0.
Reference:
.. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating \
deep network training by reducing internal covariate shift." *ICML 2015*
.. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, \
Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*
)code" ADD_FILELINE)
.add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization")
.add_argument("gamma", "NDArray-or-Symbol", "gamma array")
.add_argument("beta", "NDArray-or-Symbol", "beta array")
.add_argument("moving_mean", "NDArray-or-Symbol", "running mean of input")
.add_argument("moving_var", "NDArray-or-Symbol", "running variance of input")
.add_arguments(SyncBatchNormParam::__FIELDS__());
// Extends the `_contrib_SyncBatchNorm` registration with the alias
// `_npx_sync_batch_norm` and an FSetInputVarAttrOnCompose hook. The hook
// gives input variables 3 and 4 a default "__init__" attribute ("zero" and
// "one" respectively) unless the variable already carries one.
// NOTE(review): indices 3 and 4 presumably correspond to moving_mean and
// moving_var, following the argument order of the op-property registration.
NNVM_REGISTER_OP(_contrib_SyncBatchNorm)
    .add_alias("_npx_sync_batch_norm")
    .set_attr<nnvm::FSetInputVarAttrOnCompose>(
        "FSetInputVarAttrOnCompose",
        [](const nnvm::NodeAttrs& attrs, nnvm::ObjectPtr var, const int index) {
          auto& var_dict = var->attrs.dict;
          // An initializer that is already present always wins.
          if (var_dict.count("__init__") != 0) {
            return;
          }
          switch (index) {
            case 3:
              var_dict["__init__"] = "[\"zero\", {}]";
              break;
            case 4:
              var_dict["__init__"] = "[\"one\", {}]";
              break;
            default:
              // Every other input keeps whatever attributes it already has.
              break;
          }
        });
} // namespace op
} // namespace mxnet