Fix inverted safe_acc check in the general case of layer norm (#19806)
diff --git a/src/operator/nn/layer_norm-inl.h b/src/operator/nn/layer_norm-inl.h
index b440e3d..2c309ff 100644
--- a/src/operator/nn/layer_norm-inl.h
+++ b/src/operator/nn/layer_norm-inl.h
@@ -124,7 +124,7 @@
// Calculate mean
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
- if (safe_acc) {
+ if (!safe_acc) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
s, mean_data, req[0], workspace, in_data);
} else {
@@ -149,7 +149,7 @@
const TBlob centered_out = outputs[0].reshape(red_src_shape);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
- if (safe_acc) {
+ if (!safe_acc) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::square, false>(
s, std_data, req[0], workspace, centered_out);
} else {
@@ -290,7 +290,7 @@
if (req[2] != kNullOp) {
MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
- if (safe_acc) {
+ if (!safe_acc) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
ograd.reshape(red_exclude_src_shape));
@@ -313,7 +313,7 @@
if (req[1] != kNullOp) {
MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
- if (safe_acc) {
+ if (!safe_acc) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
ograd_mult.reshape(red_exclude_src_shape));
@@ -347,7 +347,7 @@
#endif // !defined(__CUDACC__)
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
- if (safe_acc) {
+ if (!safe_acc) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
ograd_mult.reshape(red_src_shape));
@@ -375,7 +375,7 @@
#endif // !defined(__CUDACC__)
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
- if (safe_acc) {
+ if (!safe_acc) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
ograd_mult.reshape(red_src_shape));