Use cuDNN for conv bias and bias grad (#20771)

* Use cuDNN for conv bias and bias grad

* Environment variables to use native add-bias and bias-grad

* Handle 3D tensors in cuDNN legacy API

* Fix AMP for ndarray.numpy

* Remove env vars that were used for benchmarking

Co-authored-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
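
For reference, a minimal standalone sketch of the legacy-API fallback pattern this change introduces. The real LegacyAddBias()/LegacyBiasGrad() in src/operator/cudnn_ops.cc operate on MXNet's TBlob/OpContext types; TryLegacyAddBias below is only an illustrative name, and it assumes a packed 4D NCHW tensor, float/half scaling factors (the patch's AlphaBeta() also handles double), and an already-created cudnnHandle_t with valid device pointers. Error checking on descriptor create/set is omitted for brevity.

#include <cudnn.h>

#include <vector>

// Try the cuDNN legacy API to add a (C,)-shaped bias to a packed NCHW output y.
// Returns false on CUDNN_STATUS_NOT_SUPPORTED so the caller can fall back to the
// generic broadcast-add kernel, mirroring the
// `li.channel_last || !cudnn::LegacyAddBias(...)` checks in convolution.cu.
bool TryLegacyAddBias(cudnnHandle_t handle,
                      cudnnDataType_t dtype,           // e.g. CUDNN_DATA_FLOAT
                      const std::vector<int>& y_dims,  // e.g. {N, C, H, W}
                      void* y,
                      void* bias) {
  // Packed strides for the output, e.g. {C*H*W, H*W, W, 1} for NCHW.
  std::vector<int> y_strides(y_dims.size(), 1);
  for (int i = static_cast<int>(y_dims.size()) - 2; i >= 0; --i)
    y_strides[i] = y_strides[i + 1] * y_dims[i + 1];

  // Describe the bias as a broadcastable (1, C, 1, 1) tensor, as
  // SetLegacyCTensorExpandDims() does in the patch.
  std::vector<int> b_dims(y_dims.size(), 1);
  b_dims[1] = y_dims[1];
  std::vector<int> b_strides(y_dims.size(), 1);
  b_strides[0] = y_dims[1];

  cudnnTensorDescriptor_t y_desc{}, b_desc{};
  cudnnCreateTensorDescriptor(&y_desc);
  cudnnCreateTensorDescriptor(&b_desc);
  cudnnSetTensorNdDescriptor(y_desc, dtype, static_cast<int>(y_dims.size()),
                             y_dims.data(), y_strides.data());
  cudnnSetTensorNdDescriptor(b_desc, dtype, static_cast<int>(b_dims.size()),
                             b_dims.data(), b_strides.data());

  // y = 1 * bias + 1 * y: accumulate the broadcast bias into the conv output.
  float alpha = 1.0f, beta = 1.0f;
  cudnnStatus_t err = cudnnAddTensor(handle, &alpha, b_desc, bias, &beta, y_desc, y);

  cudnnDestroyTensorDescriptor(y_desc);
  cudnnDestroyTensorDescriptor(b_desc);

  if (err == CUDNN_STATUS_NOT_SUPPORTED)
    return false;  // caller falls back to BinaryBroadcastRTCCompute{"add"}
  return err == CUDNN_STATUS_SUCCESS;
}

Callers keep the generic path: convolution.cu and deconvolution.cu only take the legacy route when the layout is not channel-last and the legacy call succeeds; otherwise they fall back to BinaryBroadcastRTCCompute{"add"} / ReduceAxesRTCComputeImpl as before.
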
diff --git a/python/mxnet/amp/loss_scaler.py b/python/mxnet/amp/loss_scaler.py
index 1e464ff..46a8ed9 100644
--- a/python/mxnet/amp/loss_scaler.py
+++ b/python/mxnet/amp/loss_scaler.py
@@ -46,14 +46,14 @@
         """Check gradients for overflow."""
         if is_np_array():
             all_finite_f = ndarray.numpy._internal.multi_all_finite
-            ones_f = ndarray.numpy.ones
+            ones_f = lambda ctx: ndarray.numpy.ones((1,), device=ctx)
         else:
             all_finite_f = ndarray.multi_all_finite
-            ones_f = ndarray.ones
+            ones_f = lambda ctx: ndarray.ones((1,), ctx=ctx)
         with ag.pause():
             chunk_size = 200
             valid_params = [p._grad[0] for p in params if p._grad is not None]
-            gpu_output = ones_f((1,), ctx=valid_params[0].context)
+            gpu_output = ones_f(valid_params[0].context)
             nb_params = len(valid_params)
             for idx in range(0, nb_params, chunk_size):
                 all_finite_f(*valid_params[idx:idx+chunk_size],
diff --git a/src/common/cuda/cudnn_cxx.cc b/src/common/cuda/cudnn_cxx.cc
index 8e161b4..2259c85 100644
--- a/src/common/cuda/cudnn_cxx.cc
+++ b/src/common/cuda/cudnn_cxx.cc
@@ -112,15 +112,6 @@
   return ret;
 }
 
-std::vector<int64_t> PackedStrides(const std::vector<size_t>& order,
-                                   const std::vector<int64_t>& dims) {
-  CHECK_EQ(order.size(), dims.size());
-  std::vector<int64_t> ret(dims.size(), 1);
-  for (size_t i = dims.size() - 1; i--;)
-    ret[order[i]] = dims[order[i + 1]] * ret[order[i + 1]];
-  return ret;
-}
-
 std::vector<Descriptor> GetPlans(cudnnBackendHeurMode_t h_mode,
                                  cudnnHandle_t handle,
                                  const Descriptor& op_graph,
diff --git a/src/common/cuda/cudnn_cxx.h b/src/common/cuda/cudnn_cxx.h
index 0379a5d..07cd93d 100644
--- a/src/common/cuda/cudnn_cxx.h
+++ b/src/common/cuda/cudnn_cxx.h
@@ -244,8 +244,14 @@
                                      cudnnBackendDescriptorType_t type);
 
 // Order sets layout, as a permutation of dims, with N,C,<spatial dims> being identity.
-std::vector<int64_t> PackedStrides(const std::vector<size_t>& order,
-                                   const std::vector<int64_t>& dims);
+template <typename T>
+std::vector<T> PackedStrides(const std::vector<size_t>& order, const std::vector<T>& dims) {
+  CHECK_EQ(order.size(), dims.size());
+  std::vector<T> ret(dims.size(), 1);
+  for (size_t i = dims.size() - 1; i--;)
+    ret[order[i]] = dims[order[i + 1]] * ret[order[i + 1]];
+  return ret;
+}
 
 // Given an engine config's `notes`, return whether that config is compatible, i.e. does
 // the config have all of the required notes and none of the notes that are being excluded.
diff --git a/src/operator/cudnn_ops.cc b/src/operator/cudnn_ops.cc
index e7e649f..2b99dc7 100644
--- a/src/operator/cudnn_ops.cc
+++ b/src/operator/cudnn_ops.cc
@@ -29,12 +29,10 @@
 
 #include <dmlc/parameter.h>
 
-#include <algorithm>
 #include <cstdlib>
 #include <iomanip>
 #include <iterator>
 #include <limits>
-#include <numeric>
 #include <sstream>
 #include <string>
 #include <utility>
@@ -79,10 +77,6 @@
   return channel_last ? 1 + n_space_dims : 1;
 }
 
-std::vector<int64_t> LayoutInfo::Strides(const std::vector<int64_t>& dims) const {
-  return PackedStrides(Order(), dims);
-}
-
 LayoutInfo GetLayoutInfo(mshadow::LayoutFlag layout) {
   static std::unordered_map<mshadow::LayoutFlag, LayoutInfo> layout_map{
       {mshadow::kNCW, {1, false}},
@@ -165,14 +159,8 @@
   for (size_t i = 0; i < dims.size(); ++i)
     dims[i] = blob.shape_[rev_order[i]];
   auto strides = li.Strides(dims);
-  if (li.n_space_dims == 1 && expand_1d) {
-    dims.insert(dims.begin() + 2, 1);
-    std::vector<size_t> order(dims.size());
-    std::iota(order.begin(), order.end(), 0);
-    if (li.channel_last)
-      std::rotate(order.begin() + 1, order.begin() + 2, order.end());
-    strides = PackedStrides(order, dims);
-  }
+  if (expand_1d)
+    li.ExpandIf1d(&dims, &strides);
   return MakeTensorDesc(
       uid, CudnnType(static_cast<mshadow::TypeFlag>(blob.type_flag_)), dims, strides, is_virtual);
 }
@@ -758,6 +746,109 @@
   CUDNN_CALL(cudnnBackendExecute(s->dnn_handle_, plan.get(), var_pack.get()));
 }
 
+struct LegacyTensorDestroyer {
+  using pointer = cudnnTensorDescriptor_t;
+
+  void operator()(cudnnTensorDescriptor_t desc) {
+    CUDNN_CALL_NONFATAL(cudnnDestroyTensorDescriptor(desc));
+  }
+};
+
+using LegacyTensor = std::unique_ptr<cudnnTensorDescriptor_t, LegacyTensorDestroyer>;
+
+LegacyTensor MakeLegacyTensor() {
+  cudnnTensorDescriptor_t desc{};
+  CUDNN_CALL(cudnnCreateTensorDescriptor(&desc));
+  return LegacyTensor(desc);
+}
+
+union ScalingParam {
+  double d;
+  float f;
+};
+
+std::pair<ScalingParam, ScalingParam> AlphaBeta(int type_flag, double init_a, double init_b) {
+  ScalingParam a, b;
+  switch (type_flag) {
+    case kFloat64:
+      a.d = init_a;
+      b.d = init_b;
+      break;
+    case kFloat32:  // fallthrough
+    case kFloat16:
+      a.f = init_a;
+      b.f = init_b;
+      break;
+    default:
+      LOG(FATAL) << "Unexpected type: " << type_flag;
+  }
+  return {a, b};
+}
+
+void SetLegacyTensor(cudnnTensorDescriptor_t desc, const TBlob& blob, const LayoutInfo& li) {
+  std::vector<int> dims(blob.shape_.ndim());
+  CHECK_EQ(dims.size(), li.n_space_dims + 2);
+  auto rev_order = ReverseOrder(li.Order());
+  for (size_t i = 0; i < dims.size(); ++i)
+    dims[i] = blob.shape_[rev_order[i]];
+  auto strides = li.Strides(dims);
+  li.ExpandIf1d(&dims, &strides);
+  auto type = static_cast<mshadow::TypeFlag>(blob.type_flag_);
+  CUDNN_CALL(cudnnSetTensorNdDescriptor(desc, CudnnType(type), dims.size(), &dims[0], &strides[0]));
+}
+
+void SetLegacyCTensorExpandDims(cudnnTensorDescriptor_t desc,
+                                const TBlob& blob,
+                                const LayoutInfo& li) {
+  std::vector<int> dims(li.n_space_dims + 2, 1);
+  dims[1] = blob.shape_[0];
+  std::vector<int> strides(dims.size(), 1);
+  strides[0] = blob.shape_[0];
+  li.ExpandIf1d(&dims, &strides);
+  auto type = static_cast<mshadow::TypeFlag>(blob.type_flag_);
+  CUDNN_CALL(cudnnSetTensorNdDescriptor(desc, CudnnType(type), dims.size(), &dims[0], &strides[0]));
+}
+
+bool LegacyAddBias(const OpContext& ctx, const LayoutInfo& li, const TBlob& y, const TBlob& b) {
+  thread_local auto y_desc = MakeLegacyTensor();
+  thread_local auto b_desc = MakeLegacyTensor();
+
+  auto s             = ctx.get_stream<gpu>();
+  auto [alpha, beta] = AlphaBeta(y.type_flag_, 1.0, 1.0);  // NOLINT(whitespace/braces)
+
+  SetLegacyTensor(y_desc.get(), y, li);
+  SetLegacyCTensorExpandDims(b_desc.get(), b, li);
+
+  auto err =
+      cudnnAddTensor(s->dnn_handle_, &alpha, b_desc.get(), b.dptr_, &beta, y_desc.get(), y.dptr_);
+  if (err == CUDNN_STATUS_NOT_SUPPORTED)
+    return false;
+  CHECK_EQ(err, CUDNN_STATUS_SUCCESS);
+  return true;
+}
+
+bool LegacyBiasGrad(const OpContext& ctx,
+                    const LayoutInfo& li,
+                    bool add_to,
+                    const TBlob& db,
+                    const TBlob& dy) {
+  thread_local auto db_desc = MakeLegacyTensor();
+  thread_local auto dy_desc = MakeLegacyTensor();
+
+  auto s             = ctx.get_stream<gpu>();
+  auto [alpha, beta] = AlphaBeta(dy.type_flag_, 1.0, add_to ? 1.0 : 0.0);  // NOLINT(*)
+
+  SetLegacyCTensorExpandDims(db_desc.get(), db, li);
+  SetLegacyTensor(dy_desc.get(), dy, li);
+
+  auto err = cudnnConvolutionBackwardBias(
+      s->dnn_handle_, &alpha, dy_desc.get(), dy.dptr_, &beta, db_desc.get(), db.dptr_);
+  if (err == CUDNN_STATUS_NOT_SUPPORTED)
+    return false;
+  CHECK_EQ(err, CUDNN_STATUS_SUCCESS);
+  return true;
+}
+
 }  // namespace cudnn
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/cudnn_ops.h b/src/operator/cudnn_ops.h
index 60b45ad..5f24a7e 100644
--- a/src/operator/cudnn_ops.h
+++ b/src/operator/cudnn_ops.h
@@ -29,7 +29,9 @@
 
 #include <mxnet/op_attr_types.h>
 
+#include <algorithm>
 #include <mutex>
+#include <numeric>
 #include <tuple>
 #include <unordered_map>
 #include <utility>
@@ -89,7 +91,23 @@
 
   std::vector<size_t> Order() const;
   size_t ChannelIdx() const;
-  std::vector<int64_t> Strides(const std::vector<int64_t>& dims) const;
+
+  template <typename T>
+  std::vector<T> Strides(const std::vector<T>& dims) const {
+    return cudnn_cxx::PackedStrides(Order(), dims);
+  }
+
+  template <typename T>
+  void ExpandIf1d(std::vector<T>* dims, std::vector<T>* strides) const {
+    if (n_space_dims != 1)
+      return;
+    dims->insert(dims->begin() + 2, 1);
+    std::vector<size_t> order(dims->size());
+    std::iota(order.begin(), order.end(), 0);
+    if (channel_last)
+      std::rotate(order.begin() + 1, order.begin() + 2, order.end());
+    *strides = cudnn_cxx::PackedStrides(order, *dims);
+  }
 };
 
 LayoutInfo GetLayoutInfo(mshadow::LayoutFlag layout);
@@ -246,6 +264,14 @@
                    const TBlob& dw);
 };
 
+bool LegacyAddBias(const OpContext& ctx, const LayoutInfo& li, const TBlob& y, const TBlob& b);
+
+bool LegacyBiasGrad(const OpContext& ctx,
+                    const LayoutInfo& li,
+                    bool add_to,
+                    const TBlob& db,
+                    const TBlob& dy);
+
 }  // namespace cudnn
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/nn/convolution.cu b/src/operator/nn/convolution.cu
index 74cb872..84d6f5f 100644
--- a/src/operator/nn/convolution.cu
+++ b/src/operator/nn/convolution.cu
@@ -57,14 +57,18 @@
     if (ok && !param.no_bias) {
       CHECK_EQ(inputs[conv::kBias].shape_.ndim(), 1);
       auto layout = static_cast<mshadow::LayoutFlag>(param.layout.value());
-      int k       = inputs[conv::kBias].shape_.Size();
-      auto b      = inputs[conv::kBias].reshape(cudnn::ExpandChannelDims(layout, k));
-      BinaryBroadcastRTCCompute{"add"}(  // NOLINT(whitespace/braces)
-          attrs,
-          ctx,
-          {outputs[conv::kOut], b},
-          {kWriteInplace},
-          {outputs[conv::kOut]});
+      auto li     = cudnn::GetLayoutInfo(layout);
+      if (li.channel_last ||
+          !cudnn::LegacyAddBias(ctx, li, outputs[conv::kOut], inputs[conv::kBias])) {
+        int k  = inputs[conv::kBias].shape_.Size();
+        auto b = inputs[conv::kBias].reshape(cudnn::ExpandChannelDims(layout, k));
+        BinaryBroadcastRTCCompute{"add"}(  // NOLINT(whitespace/braces)
+            attrs,
+            ctx,
+            {outputs[conv::kOut], b},
+            {kWriteInplace},
+            {outputs[conv::kOut]});
+      }
     }
     if (!ok) {
       if (!param.cudnn_off)
@@ -137,17 +141,21 @@
                 cudnn::Exec<cudnn::ConvWgrad>(
                     ctx, conv_param, inputs[1 + conv::kData], inputs[0], outputs[conv::kWeight]));
     if (ok && !param.no_bias && req[conv::kBias] != kNullOp) {
-      auto li = cudnn::GetLayoutInfo(static_cast<mshadow::LayoutFlag>(param.layout.value()));
-      if (li.channel_last) {
-        // This kernel should be faster.
-        auto y_grad = FlattenAs2DHead<gpu, DType>(inputs[0], ctx);
-        AddBiasGrad(outputs[conv::kBias], y_grad, req[conv::kBias], param.num_filter, ctx);
-      } else {
-        TShape axes{static_cast<int>(li.ChannelIdx())};
-        TShape small =
-            ReduceAxesShapeImpl(inputs[0].shape_, dmlc::optional<mxnet::TShape>(axes), true, true);
-        ReduceAxesRTCComputeImpl(
-            ctx, {inputs[0]}, {req[conv::kBias]}, {outputs[conv::kBias]}, small, "red::sum{}");
+      auto li     = cudnn::GetLayoutInfo(static_cast<mshadow::LayoutFlag>(param.layout.value()));
+      auto add_to = req[conv::kBias] == kAddTo;
+      if (li.channel_last ||
+          !cudnn::LegacyBiasGrad(ctx, li, add_to, outputs[conv::kBias], inputs[0])) {
+        if (li.channel_last) {
+          // This kernel should be faster.
+          auto y_grad = FlattenAs2DHead<gpu, DType>(inputs[0], ctx);
+          AddBiasGrad(outputs[conv::kBias], y_grad, req[conv::kBias], param.num_filter, ctx);
+        } else {
+          TShape axes{static_cast<int>(li.ChannelIdx())};
+          TShape small = ReduceAxesShapeImpl(
+              inputs[0].shape_, dmlc::optional<mxnet::TShape>(axes), true, true);
+          ReduceAxesRTCComputeImpl(
+              ctx, {inputs[0]}, {req[conv::kBias]}, {outputs[conv::kBias]}, small, "red::sum{}");
+        }
       }
     }
     if (!ok) {
diff --git a/src/operator/nn/deconvolution.cu b/src/operator/nn/deconvolution.cu
index ec97f82..a58c12d 100644
--- a/src/operator/nn/deconvolution.cu
+++ b/src/operator/nn/deconvolution.cu
@@ -56,14 +56,18 @@
     if (ok && !param.no_bias) {
       CHECK_EQ(inputs[deconv::kBias].shape_.ndim(), 1);
       auto layout = static_cast<mshadow::LayoutFlag>(param.layout.value());
-      int k       = inputs[deconv::kBias].shape_.Size();
-      auto b      = inputs[deconv::kBias].reshape(cudnn::ExpandChannelDims(layout, k));
-      BinaryBroadcastRTCCompute{"add"}(  // NOLINT(whitespace/braces)
-          attrs,
-          ctx,
-          {outputs[deconv::kOut], b},
-          {kWriteInplace},
-          {outputs[deconv::kOut]});
+      auto li     = cudnn::GetLayoutInfo(layout);
+      if (li.channel_last ||
+          !cudnn::LegacyAddBias(ctx, li, outputs[deconv::kOut], inputs[deconv::kBias])) {
+        int k  = inputs[deconv::kBias].shape_.Size();
+        auto b = inputs[deconv::kBias].reshape(cudnn::ExpandChannelDims(layout, k));
+        BinaryBroadcastRTCCompute{"add"}(  // NOLINT(whitespace/braces)
+            attrs,
+            ctx,
+            {outputs[deconv::kOut], b},
+            {kWriteInplace},
+            {outputs[deconv::kOut]});
+      }
     }
     if (!ok) {
       if (!param.cudnn_off)
@@ -115,17 +119,25 @@
           cudnn::Exec<cudnn::ConvWgrad>(
               ctx, conv_param, inputs[0], inputs[1 + deconv::kData], outputs[deconv::kWeight]));
     if (ok && !param.no_bias && req[deconv::kBias] != kNullOp) {
-      auto li = cudnn::GetLayoutInfo(static_cast<mshadow::LayoutFlag>(param.layout.value()));
-      if (li.channel_last) {
-        // This kernel should be faster.
-        auto y_grad = FlattenAs2DHead<gpu, DType>(inputs[0], ctx);
-        AddBiasGrad(outputs[deconv::kBias], y_grad, req[deconv::kBias], param.num_filter, ctx);
-      } else {
-        TShape axes{static_cast<int>(li.ChannelIdx())};
-        TShape small =
-            ReduceAxesShapeImpl(inputs[0].shape_, dmlc::optional<mxnet::TShape>(axes), true, true);
-        ReduceAxesRTCComputeImpl(
-            ctx, {inputs[0]}, {req[deconv::kBias]}, {outputs[deconv::kBias]}, small, "red::sum{}");
+      auto li     = cudnn::GetLayoutInfo(static_cast<mshadow::LayoutFlag>(param.layout.value()));
+      auto add_to = req[deconv::kBias] == kAddTo;
+      if (li.channel_last ||
+          !cudnn::LegacyBiasGrad(ctx, li, add_to, outputs[deconv::kBias], inputs[0])) {
+        if (li.channel_last) {
+          // This kernel should be faster.
+          auto y_grad = FlattenAs2DHead<gpu, DType>(inputs[0], ctx);
+          AddBiasGrad(outputs[deconv::kBias], y_grad, req[deconv::kBias], param.num_filter, ctx);
+        } else {
+          TShape axes{static_cast<int>(li.ChannelIdx())};
+          TShape small = ReduceAxesShapeImpl(
+              inputs[0].shape_, dmlc::optional<mxnet::TShape>(axes), true, true);
+          ReduceAxesRTCComputeImpl(ctx,
+                                   {inputs[0]},
+                                   {req[deconv::kBias]},
+                                   {outputs[deconv::kBias]},
+                                   small,
+                                   "red::sum{}");
+        }
       }
     }
     if (!ok) {