[v1.6.x][Bug Fixed] Fix batch norm when grad_req is `add` (#18518) (#18714)
* [Bug Fixed] Fix batch norm when grad_req is `add` (#18500)
* fix batch norm when fix_gamma is True
* support gradient accumulation for batch norm
* mkldnn batchnorm support grad add
* unittest for bn
* fix bn arg
* fix lint
* fix mkldnn
* fix mkldnn bn
* fix grad when fixing gamma
* fix naive gpu bn
* fix lint
* fix cudnn bn
* fix flag
* fix lint
* fix testcase
* fix
* use @pytest.mark.parametrize
* combination
* remove redundant test in batchnorm
* npx.batch_norm test
* try to fix test
* reduce the number of tests for batchnorm
* fix
* Revert "[Bug Fixed] Fix batch norm when grad_req is `add` (#18500)"
This reverts commit 8e32cd6959461290c1698e02466fcc16f61ad237.
* [v1.x] backport #18500 - [Bug Fixed] Fix batch norm when grad_req is `add` (#18518)
* Fix batch norm when grad_req is `add`
* fix
* remove softmax test
* fix
* add copy_size
* Fix init method for TestBatchNorm
Co-authored-by: JackieWu <wkcn@live.cn>
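
The patch makes every batch-norm backend (CPU, naive GPU, cuDNN, MKLDNN) honor `req == kAddTo` for the data, gamma, and beta gradients instead of unconditionally overwriting them. A minimal sketch of the user-visible scenario this fixes; shapes and values are illustrative, not part of the patch:

```python
# Gradient accumulation over two backward passes with grad_req='add'.
import mxnet as mx

data = mx.nd.random.uniform(shape=(24, 8, 4, 5))
gamma = mx.nd.random.uniform(shape=(8,))
beta = mx.nd.random.uniform(shape=(8,))
mean, var = mx.nd.zeros(8), mx.nd.ones(8)

for arr in (data, gamma, beta):
    arr.attach_grad(grad_req='add')  # accumulate instead of overwrite

for _ in range(2):
    with mx.autograd.record():
        out = mx.nd.BatchNorm(data, gamma, beta, mean, var,
                              momentum=0.9, eps=1e-5, fix_gamma=False)
    out.backward(mx.nd.ones_like(out))
# Before this fix the second backward overwrote the gradient buffers; with
# it, data.grad, gamma.grad, and beta.grad hold the sum of both passes.
```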
diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h
index 17a16db..485b3b3 100644
--- a/src/operator/nn/batch_norm-inl.h
+++ b/src/operator/nn/batch_norm-inl.h
@@ -259,6 +259,7 @@
const std::vector<TBlob> &outputs) {
CHECK_EQ(inputs.size(), 8U);
CHECK_EQ(outputs.size(), 3U);
+
std::vector<TBlob> out_grad(1);
std::vector<TBlob> out_data(3);
std::vector<TBlob> in_data(3);
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 3214e3b..fc65476 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -85,6 +85,31 @@
}
}
+template<typename DType1, typename DType2, typename DType3, typename OnData>
+static inline void ForEachFast(const BNTensor3<DType1> &in_data,
+ const BNTensor3<DType2> &in_data2,
+ const BNTensor3<DType3> &out_data,
+ const size_t channel,
+ OnData onData) {
+ const size_t num = in_data.OuterSize();
+ const size_t matrixSize = in_data.InnerSize();
+ const size_t skipLength = in_data.SkipLengthToNextSameChannelData();
+ const size_t startOffset = in_data.StartOffset(channel);
+
+ DType1 *data = in_data.dptr_ + startOffset;
+ DType2 *data2 = in_data2.dptr_ + startOffset;
+ DType3 *odata = out_data.dptr_ + startOffset;
+
+ for (size_t outer = 0; outer < num; ++outer) {
+ for (size_t i = 0; i < matrixSize; ++i) {
+ onData(data++, data2++, odata++);
+ }
+ data += skipLength;
+ data2 += skipLength;
+ odata += skipLength;
+ }
+}
+
} // namespace batchnorm
/*! \brief Forward CPU */
@@ -264,7 +289,7 @@
dotp += (*thisInputData - mean) * (*gradOut_data);
});
- if (!gradIn.IsEmpty() && IsBNWriting(req[batchnorm::kData])) { // if there's a grad input
+ if (!gradIn.IsEmpty() && req[batchnorm::kData] != kNullOp) { // if there's a grad input
if (is_train_and_not_global_stats) {
// when in training mode
// Q(X) = X - E[x] ; i.e. input centered to zero mean
@@ -273,44 +298,60 @@
// projection of gradOutput on to output scaled by std
const AccReal k = dotp * invstd * invstd / itemCount;
- ForEachFast(inputData, gradIn, static_cast<size_t>(channel),
- [&mean, &k](const DType *inputDataPtr, DType *gradIn_data) {
- *gradIn_data = (*inputDataPtr - mean) * k;
- });
-
const AccReal iw = invstd * w;
const AccReal gradMean = sumGradOut / itemCount;
- ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
- [iw, gradMean](const DType *gradOut_data, DType *gradIn_data) {
- *gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * iw;
- });
+ if (req[batchnorm::kData] != kAddTo) {
+ ForEachFast(inputData, gradIn, static_cast<size_t>(channel),
+ [&mean, &k](const DType *inputDataPtr, DType *gradIn_data) {
+ *gradIn_data = (*inputDataPtr - mean) * k;
+ });
+
+ ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
+ [iw, gradMean](const DType *gradOut_data, DType *gradIn_data) {
+ *gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * iw;
+ });
+ } else {
+ ForEachFast(inputData, gradOut, gradIn, static_cast<size_t>(channel),
+ [&mean, &k, iw, gradMean](const DType *inputDataPtr,
+ const DType *gradOut_data,
+ DType *gradIn_data) {
+ DType normal_val = (*inputDataPtr - mean) * k;
+ *gradIn_data += (*gradOut_data - gradMean -
+ normal_val) * iw;
+ });
+ }
} else {
// when in evaluation mode
// Q(X) = X - running_mean ; i.e. input centered to zero mean
// Y = Q(X) / running_std ; i.e. BN output before weight and bias
// dL/dX = w / running_std
const AccReal iw = invstd * w;
- ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
- [iw](const DType *gradOut_data, DType *gradIn_data) {
- *gradIn_data = *gradOut_data * iw;
- });
+ if (req[batchnorm::kData] != kAddTo) {
+ ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
+ [iw](const DType *gradOut_data, DType *gradIn_data) {
+ *gradIn_data = *gradOut_data * iw;
+ });
+ } else {
+ ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
+ [iw](const DType *gradOut_data, DType *gradIn_data) {
+ *gradIn_data += *gradOut_data * iw;
+ });
+ }
}
}
// May want to make this a param eventually
const AccReal scale = 1.0f;
- if (IsBNWriting(req[batchnorm::kGamma])) {
- if (!param_.fix_gamma) {
- gradWeightData[channel] = scale * dotp * invstd;
- } else {
+ if (!param_.fix_gamma) {
+ KERNEL_ASSIGN(gradWeightData[channel], req[batchnorm::kGamma], scale * dotp * invstd);
+ } else {
+ if (IsBNWriting(req[batchnorm::kGamma])) {
gradWeightData[channel] = AccReal(0);
}
}
- if (IsBNWriting(req[batchnorm::kBeta])) {
- gradBiasData[channel] = scale * sumGradOut;
- }
+ KERNEL_ASSIGN(gradBiasData[channel], req[batchnorm::kBeta], scale * sumGradOut);
}
}
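
Why the CPU path needs the new three-tensor `ForEachFast`: the original training-mode code computed the data gradient in two passes, using `gradIn` itself as scratch space (the first pass writes `(x - mean) * k` into it, the second rewrites it). That is incompatible with `kAddTo`, so the patch adds a fused single pass that reads only the inputs and adds into `gradIn`. A NumPy sketch (illustrative constants) of the algebraic equivalence:

```python
import numpy as np

x = np.random.rand(6)                  # one channel of inputData
g = np.random.rand(6)                  # one channel of gradOut
mean, grad_mean = x.mean(), g.mean()
k, iw = 0.3, 1.7                       # stand-ins for the AccReal constants

# two-pass form: gradIn is used as scratch, so kAddTo would corrupt it
scratch = (x - mean) * k               # first ForEachFast writes this to gradIn
written = (g - grad_mean - scratch) * iw   # second pass overwrites gradIn

# fused form: reads only x and g, so '+=' into gradIn is safe
prior = np.random.rand(6)              # gradient already sitting in gradIn
accumulated = prior + (g - grad_mean - (x - mean) * k) * iw

assert np.allclose(accumulated, prior + written)
```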
diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu
index be9309c..7b36d25 100644
--- a/src/operator/nn/batch_norm.cu
+++ b/src/operator/nn/batch_norm.cu
@@ -34,6 +34,9 @@
#define FIX_GAMMA_FLAG 8
#define IS_TRAINING_FLAG 16
#define USE_GLOBAL_STATS_FLAG 32
+#define ADDTO_DATA_FLAG (1 << 6)
+#define ADDTO_GAMMA_FLAG (1 << 7)
+#define ADDTO_BETA_FLAG (1 << 8)
#if MXNET_USE_CUDNN == 1
#include "./cudnn/cudnn_batch_norm-inl.h"
@@ -362,33 +365,60 @@
* momentum + localVariance * (AccReal(1) - momentum);
}
- if (gradInput.Size() > 0 && (flags & WRITE_DATA_FLAG) != 0) {
- for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
- for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
- const DType gradOut = gradOutput.get_ref(batch, plane, x);
- if (is_train_and_not_global_stats) {
- const DType inp = input.get_ref(batch, plane, x);
- const AccReal proj = (inp - mean) * projScale;
- gradInput.get_ref(batch, plane, x) =
- ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
- } else {
- gradInput.get_ref(batch, plane, x) = ScalarConvert<AccReal, DType>::to(
- gradOut * gradScale);
+ if (gradInput.Size() > 0 && (flags & (WRITE_DATA_FLAG | ADDTO_DATA_FLAG)) != 0) {
+ const bool grad_write = flags & WRITE_DATA_FLAG;
+ if (grad_write) {
+ for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
+ for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
+ const DType gradOut = gradOutput.get_ref(batch, plane, x);
+ if (is_train_and_not_global_stats) {
+ const DType inp = input.get_ref(batch, plane, x);
+ const AccReal proj = (inp - mean) * projScale;
+ gradInput.get_ref(batch, plane, x) =
+ ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
+ } else {
+ gradInput.get_ref(batch, plane, x) = ScalarConvert<AccReal, DType>::to(
+ gradOut * gradScale);
+ }
+ }
+ }
+ } else {
+ // grad addto
+ for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
+ for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
+ const DType gradOut = gradOutput.get_ref(batch, plane, x);
+ if (is_train_and_not_global_stats) {
+ const DType inp = input.get_ref(batch, plane, x);
+ const AccReal proj = (inp - mean) * projScale;
+ gradInput.get_ref(batch, plane, x) +=
+ ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
+ } else {
+ gradInput.get_ref(batch, plane, x) += ScalarConvert<AccReal, DType>::to(
+ gradOut * gradScale);
+ }
}
}
}
}
- if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_GAMMA_FLAG) != 0) {
+ if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 &&
+ (flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) != 0) {
if ((flags & FIX_GAMMA_FLAG) == 0) {
- tensors.gradWeight[plane] = ScalarConvert<AccReal, DType>::to(dotP * invstd);
+ if (flags & WRITE_GAMMA_FLAG)
+ tensors.gradWeight[plane] = ScalarConvert<AccReal, DType>::to(dotP * invstd);
+ else
+ tensors.gradWeight[plane] += ScalarConvert<AccReal, DType>::to(dotP * invstd);
} else {
tensors.gradWeight[plane] = DType(0);
}
}
- if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_BETA_FLAG) != 0) {
- tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
+ if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 &&
+ (flags & (WRITE_BETA_FLAG | ADDTO_BETA_FLAG)) != 0) {
+ if (flags & WRITE_BETA_FLAG)
+ tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
+ else
+ tensors.gradBias[plane] += ScalarConvert<AccReal, DType>::to(gradOutputSum);
}
}
@@ -585,12 +615,18 @@
flags |= params.use_global_stats ? USE_GLOBAL_STATS_FLAG : 0;
if (IsBNWriting(req[batchnorm::kData])) {
flags |= WRITE_DATA_FLAG;
+ } else if (req[batchnorm::kData] == kAddTo) {
+ flags |= ADDTO_DATA_FLAG;
}
if (IsBNWriting(req[batchnorm::kGamma])) {
flags |= WRITE_GAMMA_FLAG;
+ } else if (req[batchnorm::kGamma] == kAddTo) {
+ flags |= ADDTO_GAMMA_FLAG;
}
if (IsBNWriting(req[batchnorm::kBeta])) {
flags |= WRITE_BETA_FLAG;
+ } else if (req[batchnorm::kBeta] == kAddTo) {
+ flags |= ADDTO_BETA_FLAG;
}
return flags;
}
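
The naive GPU kernel receives the per-output `grad_req` packed into a bit flag, so the patch extends the encoding with three `ADDTO_*` bits above the existing ones and branches on write-vs-add inside the kernel. A Python sketch of the layout; the `FIX_GAMMA`/`IS_TRAINING`/`USE_GLOBAL_STATS` values are from the hunk above, while the `WRITE_*` values are assumed to occupy the low bits:

```python
WRITE_DATA_FLAG, WRITE_GAMMA_FLAG, WRITE_BETA_FLAG = 1, 2, 4   # assumed
FIX_GAMMA_FLAG, IS_TRAINING_FLAG, USE_GLOBAL_STATS_FLAG = 8, 16, 32
ADDTO_DATA_FLAG, ADDTO_GAMMA_FLAG, ADDTO_BETA_FLAG = 1 << 6, 1 << 7, 1 << 8

def data_req_bits(req):
    """Mirror of the req -> flags mapping added at the end of this file."""
    if req == 'write':            # kWriteTo / kWriteInplace
        return WRITE_DATA_FLAG
    if req == 'add':              # kAddTo
        return ADDTO_DATA_FLAG
    return 0                      # kNullOp: neither bit, gradient is skipped

assert data_req_bits('add') & (WRITE_DATA_FLAG | ADDTO_DATA_FLAG)
```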
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
index 3fc9119..5dad073 100644
--- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
+++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
@@ -208,13 +208,24 @@
if (param_.fix_gamma) gamma = 1.f;
+ bool grad_add_gamma_beta = (req[cudnnbatchnorm::kGamma] == kAddTo) ||
+ (req[cudnnbatchnorm::kBeta] == kAddTo);
+ if (grad_add_gamma_beta) {
+ if (IsBNWriting(req[cudnnbatchnorm::kGamma])) {
+ dgamma = 0.f;
+ }
+ if (IsBNWriting(req[cudnnbatchnorm::kBeta])) {
+ dbeta = 0.f;
+ }
+ }
+
CUDNN_CALL(cudnnBatchNormalizationBackward(
s->dnn_handle_,
mode,
&a,
- &b,
+ req[cudnnbatchnorm::kData] == kAddTo ? &b_add : &b,
&a,
- req[cudnnbatchnorm::kGamma] == kWriteTo ? &b: &b_add,
+ grad_add_gamma_beta ? &b_add : &b, // gamma and beta
io_desc_,
x.dptr_,
io_desc_,
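
For cuDNN, accumulation is expressed through the blend factors: passing `&b_add` (beta = 1) makes `cudnnBatchNormalizationBackward` add into the existing buffer instead of overwriting it. cuDNN exposes only a single alpha/beta pair covering both dgamma and dbeta, so when gamma and beta carry different reqs the patch runs the kernel in accumulate mode and zeroes whichever buffer should be plain-written first. A sketch of that workaround, with hypothetical names and values:

```python
def backward_params(dgamma_buf, dbeta_buf, req_gamma, req_beta,
                    new_dgamma, new_dbeta):
    accumulate = 'add' in (req_gamma, req_beta)   # grad_add_gamma_beta
    if accumulate:
        if req_gamma == 'write':
            dgamma_buf = 0.0      # add-into-zero behaves like a write
        if req_beta == 'write':
            dbeta_buf = 0.0
        return dgamma_buf + new_dgamma, dbeta_buf + new_dbeta  # beta == 1
    return new_dgamma, new_dbeta  # beta == 0: plain overwrite

assert backward_params(5.0, 7.0, 'write', 'add', 2.0, 3.0) == (2.0, 10.0)
```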
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 26637c7..2e0fb64 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -317,13 +317,15 @@
else if (diff.IsDefaultData())
diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_desc());
auto &bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
- auto gradi_mem = const_cast<NDArray &>(gradIn).CreateMKLDNNData(data_mem->get_desc());
+ auto gradi_mem = CreateMKLDNNMem(const_cast<NDArray &>(gradIn),
+ bwd.pd.diff_src_desc(), req[batchnorm::kData]);
if (static_cast<int>(flags) & static_cast<int>(mkldnn::normalization_flags::use_scale_shift)) {
const NDArray &gamma = in_data[batchnorm::kGamma];
const NDArray &beta = in_data[batchnorm::kBeta];
DType *weight_buf = reinterpret_cast<DType *>(bwd.GetWeight().get_data_handle());
nnvm::dim_t channels_ = data.shape()[1];
+ const size_t copy_size = sizeof(DType) * channels_;
for (int i = 0; i < channels_; i++) {
if (!param.fix_gamma)
weight_buf[i] = (gamma.data().dptr<DType>())[i]; // weight
@@ -337,7 +339,7 @@
mkldnn_args_map_t net_args;
net_args[MKLDNN_ARG_SRC] = *data_mem;
- net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem;
+ net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem.second;
net_args[MKLDNN_ARG_SCALE_SHIFT] = bwd.GetWeight();
net_args[MKLDNN_ARG_DIFF_SCALE_SHIFT] = bwd.GetGradw();
net_args[MKLDNN_ARG_DIFF_DST] = *diff_mem;
@@ -362,26 +364,46 @@
}
net_args[MKLDNN_ARG_MEAN] = *(out_mean.GetMKLDNNData());
net_args[MKLDNN_ARG_VARIANCE] = var_mem;
- MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
- MKLDNNStream::Get()->Submit();
} else {
net_args[MKLDNN_ARG_MEAN] = *(moving_mean.GetMKLDNNData());
net_args[MKLDNN_ARG_VARIANCE] = *(moving_var.GetMKLDNNData());
- MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
- MKLDNNStream::Get()->Submit();
}
+ MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
+ CommitOutput(gradIn, gradi_mem);
+ MKLDNNStream::Get()->Submit();
// copy data from gradw_mem to in_grad[1] and in_grad[2]
DType *gw_buf = reinterpret_cast<DType *>(bwd.GetGradw().get_data_handle());
- for (int i = 0; i < channels_; i++) {
- if (!param.fix_gamma)
- (in_grad[1].data().dptr<DType>())[i] = gw_buf[i];
- else
+ DType *w_grad_1 = in_grad[batchnorm::kGamma].data().dptr<DType>();
+ DType *w_grad_2 = in_grad[batchnorm::kBeta].data().dptr<DType>();
+
+ // the gradient of gamma
+ if (!param.fix_gamma) {
+ if (req[batchnorm::kGamma] != kNullOp) {
+ if (req[batchnorm::kGamma] != kAddTo) {
+ memcpy(w_grad_1, gw_buf, copy_size);
+ } else {
+ for (int i = 0; i < channels_; i++) {
+ w_grad_1[i] += gw_buf[i];
+ }
+ }
+ }
+ } else {
+ for (int i = 0; i < channels_; i++) {
(in_grad[1].data().dptr<DType>())[i] = 0.0f;
+ }
}
- for (int i = 0; i < channels_; i++) {
- (in_grad[2].data().dptr<DType>())[i] = gw_buf[i + channels_];
+ // the gradient of beta
+ if (req[batchnorm::kBeta] != kNullOp) {
+ if (req[batchnorm::kBeta] != kAddTo) {
+ memcpy(w_grad_2, &gw_buf[channels_], copy_size);
+ } else {
+ DType *grad_beta = &gw_buf[channels_];
+ for (int i = 0; i < channels_; i++) {
+ w_grad_2[i] += grad_beta[i];
+ }
+ }
}
} else {
LOG(FATAL) << "MKLDNN batch normalization backward: should not reach here ...";
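
Two things change in the MKLDNN path. First, `gradi_mem` now goes through the `CreateMKLDNNMem`/`CommitOutput` pair, MXNet's standard pattern for req-aware outputs: for `kAddTo` it hands the primitive a temporary buffer and `CommitOutput` sums it into the destination. Second, MKLDNN emits the parameter gradients contiguously as `[dgamma | dbeta]` in `gw_buf`, and the old unconditional copy loop is replaced by req-aware copy-or-accumulate logic. A NumPy sketch of the second part (names assumed):

```python
import numpy as np

channels = 4
gw_buf = np.random.rand(2 * channels)      # bwd.GetGradw() contents
w_grad_1 = np.zeros(channels)              # in_grad[batchnorm::kGamma]
w_grad_2 = np.ones(channels)               # in_grad[batchnorm::kBeta]

req_gamma, req_beta = 'write', 'add'
if req_gamma != 'null':                    # kNullOp: leave buffer untouched
    if req_gamma != 'add':
        w_grad_1[:] = gw_buf[:channels]    # the memcpy(copy_size) branch
    else:
        w_grad_1 += gw_buf[:channels]
if req_beta != 'null':
    if req_beta != 'add':
        w_grad_2[:] = gw_buf[channels:]
    else:
        w_grad_2 += gw_buf[channels:]      # the element-wise += loop
```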
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 1ff1b61..fe2df9e 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1073,6 +1073,163 @@
@with_seed()
@use_np
+def test_npx_batch_norm():
+ momentum = 0.9
+ epsilon = 1e-5
+ class TestBatchNorm(HybridBlock):
+ def __init__(self, eps=1e-5, fix_gamma=False, momentum=0.9, **kwargs):
+ super(TestBatchNorm, self).__init__()
+ self.eps = eps
+ self.fix_gamma = fix_gamma
+ self.momentum = momentum
+ self.kwargs = kwargs
+ def hybrid_forward(self, F, data, bn_gamma, bn_beta,
+ bn_running_mean, bn_running_var):
+ op = F.npx.batch_norm
+ output = op(data, bn_gamma, bn_beta,
+ bn_running_mean, bn_running_var,
+ momentum=self.momentum, eps=self.eps,
+ fix_gamma=self.fix_gamma, **self.kwargs)
+ return output
+
+ def _test_batchnorm_impl(shape, fix_gamma, cudnn_off, output_mean_var,
+ axis,
+ data_grad_req, gamma_grad_req, beta_grad_req):
+ kwargs = dict(output_mean_var=output_mean_var)
+ kwargs.update(dict(axis=axis, cudnn_off=cudnn_off))
+ op = TestBatchNorm(eps=epsilon, fix_gamma=fix_gamma, momentum=momentum, **kwargs)
+ nch = shape[axis]
+
+ if not fix_gamma:
+ bn_gamma = np.random.uniform(size=(nch,))
+ bn_gamma.attach_grad(grad_req=gamma_grad_req)
+ else:
+ bn_gamma = np.ones((nch,))
+
+ bn_beta = np.random.uniform(size=(nch,))
+ bn_beta.attach_grad(grad_req=beta_grad_req)
+
+ bn_running_mean = np.zeros(nch)
+ bn_running_var = np.ones(nch)
+
+ running_mean = np.zeros(nch)
+ running_var = np.ones(nch)
+ num_iters = 10
+ expand_shape = [1] * len(shape)
+ expand_shape[axis] = shape[axis]
+ expand_shape = tuple(expand_shape)
+ data = np.random.uniform(size=shape)
+ data.attach_grad(grad_req=data_grad_req)
+ adX, adW, adb = 0, 0, 0
+ is_train = data_grad_req != 'null' or \
+ (not fix_gamma and gamma_grad_req != 'null') or \
+ beta_grad_req != 'null'
+ for _ in range(num_iters):
+ if data_grad_req != 'add':
+ data = np.random.uniform(size=shape)
+ data.attach_grad(grad_req=data_grad_req)
+ ograd = np.random.uniform(size=shape)
+ with mx.autograd.record():
+ output = op(data, bn_gamma, bn_beta,
+ bn_running_mean, bn_running_var)
+ if output_mean_var:
+ output, output_mean, output_std = output
+ if is_train:
+ output.backward(ograd)
+ mx.nd.waitall()
+
+ assert 0 <= axis < data.ndim
+ reduce_axis = tuple(i for i in range(data.ndim) if i != axis)
+ assert len(reduce_axis) == data.ndim - 1
+ data_mean = data.mean(
+ axis=reduce_axis, keepdims=True)
+ data_var = ((data - data_mean) ** 2).mean(axis=reduce_axis,
+ keepdims=True)
+
+ target_output = (data - data_mean) / \
+ np.sqrt(data_var + epsilon) * \
+ bn_gamma.reshape(expand_shape) + \
+ bn_beta.reshape(expand_shape)
+
+ # squeeze data_mean and data_var
+ data_mean_flat = data_mean.squeeze()
+ data_var_flat = data_var.squeeze()
+
+ running_mean = running_mean * momentum + \
+ data_mean_flat * (1 - momentum)
+ running_var = running_var * momentum + \
+ data_var_flat * (1 - momentum)
+
+ W = bn_gamma.reshape(expand_shape)
+ dnx = ograd * W
+ xsm = data - data_mean
+ nd = 1.0 / np.sqrt(data_var + epsilon)
+ nx = xsm * nd
+ m = _np.prod(shape) / shape[axis]
+ dvar = np.sum(dnx * xsm, axis=reduce_axis, keepdims=True,
+ ) * (-0.5) * np.power(nd, 3)
+ dmean = -nd * np.sum(dnx, axis=reduce_axis, keepdims=True) - \
+ dvar * xsm.mean(axis=reduce_axis, keepdims=True,
+ ) * 2.0
+ dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m)
+ dW = np.sum(ograd * nx, axis=reduce_axis)
+ db = np.sum(ograd, axis=reduce_axis)
+ adX = dX if data_grad_req != 'add' else adX + dX
+ adW = dW if gamma_grad_req != 'add' else adW + dW
+ adb = db if beta_grad_req != 'add' else adb + db
+
+ atol, rtol = 5e-2, 5e-2
+
+ if output_mean_var:
+ assert_almost_equal(output_mean.asnumpy(),
+ data_mean_flat.asnumpy(),
+ atol=atol, rtol=rtol)
+ assert_almost_equal(output_std.asnumpy(),
+ (1.0 / np.sqrt(data_var_flat +
+ epsilon)).asnumpy(),
+ atol=atol, rtol=rtol)
+ assert_almost_equal(output.asnumpy(), target_output.asnumpy(),
+ atol=atol, rtol=rtol)
+ if is_train:
+ assert_almost_equal(bn_running_mean.asnumpy(
+ ), running_mean.asnumpy(), atol=atol, rtol=rtol)
+ assert_almost_equal(bn_running_var.asnumpy(
+ ), running_var.asnumpy(), atol=atol, rtol=rtol)
+
+ if data_grad_req != 'null':
+ assert_almost_equal(data.grad.asnumpy(),
+ adX.asnumpy(), atol=atol, rtol=rtol)
+ if not fix_gamma:
+ if gamma_grad_req != 'null':
+ assert_almost_equal(
+ bn_gamma.grad.asnumpy(), adW.asnumpy(),
+ atol=atol, rtol=rtol)
+ else:
+ assert((bn_gamma.asnumpy() == 1).all())
+ if beta_grad_req != 'null':
+ assert_almost_equal(
+ bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol)
+
+ shapes = [(24, 2), (24, 3, 4), (24, 8, 4, 5), (24, 5, 6, 4, 5)]
+ bools = [False, True]
+ for shape, fix_gamma, cudnn_off, output_mean_var in itertools.product(
+ shapes, bools, bools, bools):
+ grad_reqs = ['write'] if len(shape) != 4 else ['null', 'write', 'add']
+ for data_grad_req in grad_reqs:
+ for gamma_grad_req in grad_reqs:
+ if fix_gamma and gamma_grad_req != 'null':
+ continue
+ for beta_grad_req in grad_reqs:
+ for axis in range(len(shape)):
+ _test_batchnorm_impl(
+ shape, fix_gamma, cudnn_off, output_mean_var,
+ axis,
+ data_grad_req,
+ gamma_grad_req, beta_grad_req)
+
+
+@with_seed()
+@use_np
def test_npi_boolean_assign():
class TestBooleanAssignScalar(HybridBlock):
def __init__(self, val):
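
Both this test and the ndarray test below build the expected gradients by accumulating the per-iteration analytic values: with `grad_req='add'` the input is not re-drawn inside the loop and the check target is the running sum `adX`/`adW`/`adb`. A small sketch of that pattern; `reference_grad` is a hypothetical stand-in for the inline dX/dW/db computation:

```python
import numpy as np

def reference_grad():
    # hypothetical stand-in; a fresh ograd is drawn on every pass
    return np.random.rand(3)

num_iters, data_grad_req = 10, 'add'
adX = 0
for _ in range(num_iters):
    dX = reference_grad()
    adX = dX if data_grad_req != 'add' else adX + dX
# with grad_req='add', data.grad after the loop must equal adX
```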
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 39fd16d..0dcb476 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1826,11 +1826,18 @@
momentum = 0.9
epsilon = 1e-5
- def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var):
- print(str((op, shape, axis, cudnn_off)))
+ def _test_batchnorm_impl(op_name, shape, fix_gamma, cudnn_off, output_mean_var,
+ axis,
+ data_grad_req, gamma_grad_req, beta_grad_req):
+ if op_name == 'BatchNorm':
+ op = mx.nd.BatchNorm
+ elif op_name == 'SyncBatchNorm':
+ op = mx.nd.contrib.SyncBatchNorm
+ else:
+ raise ValueError('Not supported {}'.format(op_name))
kwargs = dict(output_mean_var=output_mean_var)
- if op == mx.nd.contrib.SyncBatchNorm:
+ if op_name == 'SyncBatchNorm':
if axis != 1:
return
key = str(op) + str(shape) + str(axis)
@@ -1841,11 +1848,14 @@
kwargs.update(dict(axis=axis, cudnn_off=cudnn_off))
nch = shape[axis]
- bn_gamma = mx.nd.random.uniform(shape=(nch,))
- bn_gamma.attach_grad()
+ if not fix_gamma:
+ bn_gamma = mx.nd.random.uniform(shape=(nch,))
+ bn_gamma.attach_grad(grad_req=gamma_grad_req)
+ else:
+ bn_gamma = mx.nd.ones(shape=(nch,))
bn_beta = mx.nd.random.uniform(shape=(nch,))
- bn_beta.attach_grad()
+ bn_beta.attach_grad(grad_req=beta_grad_req)
bn_running_mean = mx.nd.zeros(nch)
bn_running_var = mx.nd.ones(nch)
@@ -1855,18 +1865,26 @@
num_iters = 10
expand_shape = [1] * len(shape)
expand_shape[axis] = shape[axis]
+ data = mx.nd.random.uniform(shape=shape)
+ data.attach_grad(grad_req=data_grad_req)
+ adX, adW, adb = 0, 0, 0
+ is_train = data_grad_req != 'null' or \
+ (not fix_gamma and gamma_grad_req != 'null') or \
+ beta_grad_req != 'null'
for _ in range(num_iters):
- data = mx.nd.random.uniform(shape=shape)
- data.attach_grad()
+ if data_grad_req != 'add':
+ data = mx.nd.random.uniform(shape=shape)
+ data.attach_grad(grad_req=data_grad_req)
ograd = mx.nd.random.uniform(shape=shape)
with mx.autograd.record():
output = op(data, bn_gamma, bn_beta,
bn_running_mean, bn_running_var,
momentum=momentum, eps=epsilon,
- fix_gamma=False, **kwargs)
+ fix_gamma=fix_gamma, **kwargs)
if output_mean_var:
output, output_mean, output_std = output
- output.backward(ograd)
+ if is_train:
+ output.backward(ograd)
mx.nd.waitall()
data_mean = data.mean(
@@ -1903,9 +1921,11 @@
dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m)
dW = (ograd * nx).sum(axis=axis, exclude=True)
db = ograd.sum(axis=axis, exclude=True)
+ adX = dX if data_grad_req != 'add' else adX + dX
+ adW = dW if gamma_grad_req != 'add' else adW + dW
+ adb = db if beta_grad_req != 'add' else adb + db
- atol = 1e-2
- rtol = 1e-2
+ atol, rtol = 5e-2, 5e-2
if output_mean_var:
assert_almost_equal(output_mean.asnumpy(),
@@ -1922,26 +1942,43 @@
atol=atol, rtol=rtol)
assert_almost_equal(output.asnumpy(), target_output.asnumpy(),
atol=atol, rtol=rtol)
- assert_almost_equal(bn_running_mean.asnumpy(
- ), running_mean.asnumpy(), atol=atol, rtol=rtol)
- assert_almost_equal(bn_running_var.asnumpy(
- ), running_var.asnumpy(), atol=atol, rtol=rtol)
+ if is_train:
+ assert_almost_equal(bn_running_mean.asnumpy(
+ ), running_mean.asnumpy(), atol=atol, rtol=rtol)
+ assert_almost_equal(bn_running_var.asnumpy(
+ ), running_var.asnumpy(), atol=atol, rtol=rtol)
- assert_almost_equal(data.grad.asnumpy(),
- dX.asnumpy(), atol=atol, rtol=rtol)
- assert_almost_equal(
- bn_gamma.grad.asnumpy(), dW.asnumpy(), atol=atol, rtol=rtol)
- assert_almost_equal(
- bn_beta.grad.asnumpy(), db.asnumpy(), atol=atol, rtol=rtol)
+ if data_grad_req != 'null':
+ assert_almost_equal(data.grad.asnumpy(),
+ adX.asnumpy(), atol=atol, rtol=rtol)
+ if not fix_gamma:
+ if gamma_grad_req != 'null':
+ assert_almost_equal(
+ bn_gamma.grad.asnumpy(), adW.asnumpy(),
+ atol=atol, rtol=rtol)
+ else:
+ assert((bn_gamma.asnumpy() == 1).all())
+ if beta_grad_req != 'null':
+ assert_almost_equal(
+ bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol)
- for op in [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]:
- for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 8, 4, 4), (24, 5, 6, 4, 4)]:
- for axis in range(len(shape)):
- for cudnn_off in [False, True]:
- for output_mean_var in [False, True]:
- _test_batchnorm_impl(op, shape, axis,
- cudnn_off, output_mean_var)
-
+ op_names = ['BatchNorm', 'SyncBatchNorm']
+ shapes = [(24, 2), (24, 3, 4), (24, 8, 4, 5), (24, 5, 6, 4, 5)]
+ bools = [False, True]
+ for op_name, shape, fix_gamma, cudnn_off, output_mean_var in itertools.product(
+ op_names, shapes, bools, bools, bools):
+ grad_reqs = ['write'] if len(shape) != 4 else ['null', 'write', 'add']
+ for data_grad_req in grad_reqs:
+ for gamma_grad_req in grad_reqs:
+ if fix_gamma and gamma_grad_req != 'null':
+ continue
+ for beta_grad_req in grad_reqs:
+ for axis in range(len(shape)):
+ _test_batchnorm_impl(
+ op_name, shape, fix_gamma, cudnn_off, output_mean_var,
+ axis,
+ data_grad_req,
+ gamma_grad_req, beta_grad_req)
@with_seed()
def test_groupnorm():