[Master] Decouple OneDNN data structures in MXNet C++ API (#20624)

* NDArray file has been modified; there are a few changes:

    1. OneDNN header was moved into the *.cc file
    2. OneDNN objects are created on the fly: a static_cast is needed

* Removed static_cast<*>

* Roll back ndarray file

* Roll back: forward namespace declaration

* Clang-format: auto-format

* Sanity fix

* Conv sanity fix

* Fix typo

* Fix ternary

* Review

* Review: test file

* Static_cast: removed redundant cast
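
The pattern applied throughout this diff: the public header forward-declares dnnl::memory and takes descriptors as const void* (dnnl::memory::desc is a nested class, so it cannot be forward-declared itself), while the implementation file includes <dnnl.hpp> and recovers the concrete type with a static_cast. The sketch below illustrates the idea only; Array and GetData are simplified stand-in names rather than the actual MXNet API, and it assumes the oneDNN headers are available when the implementation file is compiled.

    // --- public header: no oneDNN include, only an opaque forward declaration
    namespace dnnl {
    struct memory;  // dnnl::memory::desc stays invisible to API consumers
    }  // namespace dnnl

    class Array {
     public:
      // 'md' must point to a dnnl::memory::desc owned by the caller.
      const dnnl::memory* GetData(const void* md) const;
    };

    // --- implementation file: oneDNN is visible only here
    #include <dnnl.hpp>

    const dnnl::memory* Array::GetData(const void* mem_desc) const {
      // Recover the concrete descriptor type on the implementation side.
      const auto& md = *static_cast<const dnnl::memory::desc*>(mem_desc);
      (void)md.get_size();  // use md exactly as the old reference-based API did
      return nullptr;       // placeholder body for this sketch
    }

    // --- caller side: bind temporaries to a named variable before taking
    // their address, as done throughout the diff below, e.g.
    //   auto dst_desc = fwd.GetPd().dst_desc();
    //   const dnnl::memory* mem = arr.GetData(&dst_desc);
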
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index ec48984..bed166a 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -36,9 +36,6 @@
 #include <memory>
 #include <string>
 #include <vector>
-#if MXNET_USE_ONEDNN == 1
-#include <dnnl.hpp>
-#endif
 #include "./base.h"
 #include "./engine.h"
 #include "./storage.h"
@@ -47,6 +44,10 @@
 #error "cxx11 was required for ndarray module"
 #endif
 
+namespace dnnl {
+struct memory;
+}  // namespace dnnl
+
 namespace mxnet {
 // enum for storage types
 namespace csr {
@@ -743,7 +744,7 @@
    * Create NDArray from dnnl memory descriptor.
    * mem_pd The dnnl memory descriptor to be created.
    */
-  explicit NDArray(const dnnl::memory::desc& md);
+  explicit NDArray(const void* md);
   /*
    * Test if the data is stored in one of special DNNL formats.
    */
@@ -771,13 +772,13 @@
    * This function returns dnnl::memory with the given primitive_desc
    * as long as the array size meets the required size in the given primitive_desc.
    */
-  const dnnl::memory* GetDNNLData(const dnnl::memory::desc& md) const;
+  const dnnl::memory* GetDNNLData(const void* md) const;
   /*
    * This function returns dnnl::memory with the given primitive_desc.
    * The returned dnnl::memory will have the same physical layout as
    * the given primitive_desc.
    */
-  const dnnl::memory* GetDNNLDataReorder(const dnnl::memory::desc& md) const;
+  const dnnl::memory* GetDNNLDataReorder(const void* md) const;
 
   /*
    * This function copies data from dnnl memory.
@@ -787,7 +788,7 @@
    * This function allocates memory for array and creates dnnl memory
    * with the specified format.
    */
-  dnnl::memory* CreateDNNLData(const dnnl::memory::desc& md);
+  dnnl::memory* CreateDNNLData(const void* md);
 
   /*
    * These are the async version of the methods above.
@@ -795,7 +796,7 @@
    * the array are complete.
    */
   void Reorder2DefaultAsync() const;
-  void DNNLDataReorderAsync(const dnnl::memory::desc& md) const;
+  void DNNLDataReorderAsync(const void* md) const;
 
   /*
    * This creates a new NDArray with the reordered data.
@@ -826,7 +827,7 @@
   /*!
    * \ Fix dnnl memory descriptor mismatch from NDArray.
    */
-  void UpdateDNNLMemDesc(const dnnl::memory::desc& desc);
+  void UpdateDNNLMemDesc(const void* desc);
 #endif
 
   /*!
@@ -1111,7 +1112,7 @@
     // save the result in shandle.
     void Reorder2Default();
     // Reroder data to a specified layout.
-    void DNNLDataReorder(const dnnl::memory::desc& md);
+    void DNNLDataReorder(const void* md);
     bool IsDNNL() const;
     bool IsDefault() const;
 #endif
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 66870da..605e705 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -37,7 +37,9 @@
 #include "../operator/tensor/matrix_op-inl.h"
 #include "../profiler/storage_profiler.h"
 #include "./ndarray_function.h"
-
+#if MXNET_USE_ONEDNN == 1
+#include <dnnl.hpp>
+#endif
 #if MXNET_USE_OPENCV
 #include <opencv2/opencv.hpp>
 #endif  // MXNET_USE_OPENCV
@@ -211,11 +213,11 @@
 
 #if MXNET_USE_ONEDNN == 1
 
-NDArray::NDArray(const dnnl::memory::desc& md)
-    : storage_type_(kDefaultStorage), autograd_entry_(nullptr) {
-  shape_ = mxnet::TShape(md.data.dims, md.data.dims + md.data.ndims);
-  dtype_ = get_mxnet_type(md.data.data_type);
-  ptr_   = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
+NDArray::NDArray(const void* md_desc) : storage_type_(kDefaultStorage), autograd_entry_(nullptr) {
+  dnnl::memory::desc md = *static_cast<const dnnl::memory::desc*>(md_desc);
+  shape_                = mxnet::TShape(md.data.dims, md.data.dims + md.data.ndims);
+  dtype_                = get_mxnet_type(md.data.data_type);
+  ptr_                  = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
   ptr_->CheckAndAlloc(md.get_size());
   ptr_->dnnl_mem_ = std::make_shared<DNNLMemory>(md, ptr_->shandle.dptr);
 }
@@ -557,7 +559,8 @@
   dnnl_mem_ = nullptr;
 }
 
-void NDArray::Chunk::DNNLDataReorder(const dnnl::memory::desc& md) {
+void NDArray::Chunk::DNNLDataReorder(const void* mem_desc) {
+  const dnnl::memory::desc md = *static_cast<const dnnl::memory::desc*>(mem_desc);
   // If the memory already uses the specified layout, don't do anything.
   if (dnnl_mem_ != nullptr && dnnl_mem_->SameFormat(md))
     return;
@@ -623,7 +626,8 @@
   dnnl_mem_.reset(new DNNLMemory(data_md, shandle.dptr));
 }
 
-const dnnl::memory* NDArray::GetDNNLData(const dnnl::memory::desc& desc) const {
+const dnnl::memory* NDArray::GetDNNLData(const void* mem_desc) const {
+  const dnnl::memory::desc desc = *static_cast<const dnnl::memory::desc*>(mem_desc);
   if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
     LOG(FATAL) << "The size of NDArray doesn't match the requested oneDNN memory desc";
     return nullptr;
@@ -639,7 +643,8 @@
   }
 }
 
-const dnnl::memory* NDArray::GetDNNLDataReorder(const dnnl::memory::desc& new_desc) const {
+const dnnl::memory* NDArray::GetDNNLDataReorder(const void* mem_desc) const {
+  dnnl::memory::desc new_desc = *static_cast<const dnnl::memory::desc*>(mem_desc);
   CHECK(storage_type() == kDefaultStorage);
 
   const dnnl::memory* mem = GetDNNLData();
@@ -774,7 +779,8 @@
   return ret;
 }
 
-void NDArray::DNNLDataReorderAsync(const dnnl::memory::desc& desc) const {
+void NDArray::DNNLDataReorderAsync(const void* mem_desc) const {
+  dnnl::memory::desc desc = *static_cast<const dnnl::memory::desc*>(mem_desc);
   std::vector<Engine::VarHandle> const_vars;
   std::vector<Engine::VarHandle> mutable_vars(1, this->var());
   NDArray tmp        = *this;
@@ -787,7 +793,7 @@
         // MXNet will try to reuse NDArray from memory planning, so we need to ensure
         // the NDArray is still holding the original trunk data.
         if (tmp.version() == version) {
-          tmp.ptr_->DNNLDataReorder(desc);
+          tmp.ptr_->DNNLDataReorder(&desc);
         }
         on_complete();
       },
@@ -860,7 +866,8 @@
   DNNLMemoryCopy(mem, this_mem);
 }
 
-dnnl::memory* NDArray::CreateDNNLData(const dnnl::memory::desc& desc) {
+dnnl::memory* NDArray::CreateDNNLData(const void* mem_desc) {
+  dnnl::memory::desc desc = *static_cast<const dnnl::memory::desc*>(mem_desc);
   if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
     LOG(FATAL) << "The size of NDArray doesn't match the requested oneDNN memory desc. "
                << "oneDNN memory requests for " << desc.get_size() << " bytes, but got "
@@ -906,7 +913,8 @@
   return ptr_->dnnl_mem_->GetRaw();
 }
 
-void NDArray::UpdateDNNLMemDesc(const dnnl::memory::desc& desc) {
+void NDArray::UpdateDNNLMemDesc(const void* mem_desc) {
+  dnnl::memory::desc desc = *static_cast<const dnnl::memory::desc*>(mem_desc);
   auto new_desc           = desc;
   auto this_dtype         = get_dnnl_type(dtype());
   new_desc.data.data_type = static_cast<dnnl_data_type_t>(this_dtype);
diff --git a/src/operator/nn/dnnl/dnnl_act.cc b/src/operator/nn/dnnl/dnnl_act.cc
index 2cc8a34..4b51e45 100644
--- a/src/operator/nn/dnnl/dnnl_act.cc
+++ b/src/operator/nn/dnnl/dnnl_act.cc
@@ -255,8 +255,9 @@
   auto input_mem       = in_buffer.GetDNNLData();
   // We need to make sure the two inputs to eltwise_backward has the same memory
   // descriptor. Otherwise, the perf will suffer.
-  if (input_mem->get_desc() != diff_dst_memory->get_desc()) {
-    input_mem = in_buffer.GetDNNLDataReorder(diff_dst_memory->get_desc());
+  auto diff_dst_desc = diff_dst_memory->get_desc();
+  if (input_mem->get_desc() != diff_dst_desc) {
+    input_mem = in_buffer.GetDNNLDataReorder(&diff_dst_desc);
   }
 
   DNNLActBackward& bwd = GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem);
@@ -264,7 +265,8 @@
   dnnl_args_map_t args = {{DNNL_ARG_SRC, *input_mem}, {DNNL_ARG_DIFF_DST, *diff_dst_memory}};
   if (req[0] != kAddTo) {
     // req[0] is kWriteTo or kWriteInplace
-    auto diff_src_memory = const_cast<NDArray&>(in_grad).CreateDNNLData(bwd.bwd_pd.diff_src_desc());
+    auto bwd_pd_diff_src_desc = bwd.bwd_pd.diff_src_desc();
+    auto diff_src_memory      = const_cast<NDArray&>(in_grad).CreateDNNLData(&bwd_pd_diff_src_desc);
     args.insert({DNNL_ARG_DIFF_SRC, *diff_src_memory});
     stream->RegisterPrimArgs(bwd.GetBwd(), args);
     stream->Submit();
@@ -301,8 +303,9 @@
   auto input_mem       = in_buffer.GetDNNLData();
   // We need to make sure the two inputs to eltwise_backward has the same memory
   // descriptor. Otherwise, the perf will suffer.
-  if (input_mem->get_desc() != diff_dst_memory->get_desc())
-    input_mem = in_buffer.GetDNNLDataReorder(diff_dst_memory->get_desc());
+  auto diff_dst_desc = diff_dst_memory->get_desc();
+  if (input_mem->get_desc() != diff_dst_desc)
+    input_mem = in_buffer.GetDNNLDataReorder(&diff_dst_desc);
   DNNLActBackward& bwd          = GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem);
   DNNLStream* stream            = DNNLStream::Get();
   dnnl_output_t diff_src_memory = CreateDNNLMem(output, bwd.bwd_pd.diff_src_desc(), req[0]);
diff --git a/src/operator/nn/dnnl/dnnl_base.cc b/src/operator/nn/dnnl/dnnl_base.cc
index ec50da7..216a420 100644
--- a/src/operator/nn/dnnl/dnnl_base.cc
+++ b/src/operator/nn/dnnl/dnnl_base.cc
@@ -165,7 +165,7 @@
     auto tmp = TmpMemMgr::Get()->Alloc(desc);
     return dnnl_output_t(OutDataOp::AddBack, tmp);
   } else if (kWriteInplace == req && in_arr != nullptr && CanWriteTo(out_arr, *in_arr, desc)) {
-    dnnl::memory* mem = const_cast<NDArray&>(out_arr).CreateDNNLData(desc);
+    dnnl::memory* mem = const_cast<NDArray&>(out_arr).CreateDNNLData(&desc);
     // mem is nullptr if out_arr is view and desc is DNNL format.
     // need to Reorder2Default before calling CreateDNNLMem
     CHECK(mem != nullptr);
@@ -174,7 +174,7 @@
     auto tmp = TmpMemMgr::Get()->Alloc(desc);
     return dnnl_output_t(OutDataOp::CopyBack, tmp);
   } else if (kWriteTo == req) {
-    dnnl::memory* mem = const_cast<NDArray&>(out_arr).CreateDNNLData(desc);
+    dnnl::memory* mem = const_cast<NDArray&>(out_arr).CreateDNNLData(&desc);
     if (nullptr == mem) {
       auto tmp = TmpMemMgr::Get()->Alloc(desc);
       return dnnl_output_t(OutDataOp::CopyBack, tmp);
@@ -197,7 +197,7 @@
   } else {
     dnnl::memory* mem = nullptr;
     if (IsDefaultFormat(desc)) {
-      mem = const_cast<NDArray&>(out_arr).CreateDNNLData(desc);
+      mem = const_cast<NDArray&>(out_arr).CreateDNNLData(&desc);
     }
     if (mem == nullptr) {
       auto tmp = TmpMemMgr::Get()->Alloc(desc);
@@ -214,7 +214,8 @@
   } else if (res.first == AddBack) {
     auto res_memory = res.second;
     auto target_pd  = arr.GetDNNLData()->get_desc();
-    auto mem        = arr.GetDNNLData(res.second->get_desc());
+    auto res_desc   = res.second->get_desc();
+    auto mem        = arr.GetDNNLData(&res_desc);
     if (mem == nullptr) {
       auto tmp_memory = TmpMemMgr::Get()->Alloc(target_pd);
       DNNLMemoryCopy(*res_memory, tmp_memory);
@@ -272,19 +273,19 @@
     LOG(FATAL) << "The weight array has an unsupported number of dimensions";
   }
   const auto md = dnnl::memory::desc{tz, type, format_tag};
-  return arr.GetDNNLData(md);
+  return arr.GetDNNLData(&md);
 }
 
 const dnnl::memory* GetWeights(const NDArray& arr,
                                const dnnl::memory::desc& target_desc,
                                int num_groups) {
-  const dnnl::memory* mem = arr.GetDNNLData(target_desc);
+  const dnnl::memory* mem = arr.GetDNNLData(&target_desc);
   // If the weight array already uses the target layout, simply return it directly.
   if (mem)
     return mem;
   mem = GetWeights(arr, num_groups);
   if (mem == nullptr)
-    mem = arr.GetDNNLDataReorder(target_desc);
+    mem = arr.GetDNNLDataReorder(&target_desc);
   if (mem->get_desc() == target_desc)
     return mem;
 
diff --git a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h b/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
index 97f21ae..8152b6c 100644
--- a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
+++ b/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
@@ -180,7 +180,8 @@
   auto& fwd     = GetBNForward<DType>(param, ctx, data_mem, flags);
 
   // for output memory
-  auto out_mem = const_cast<NDArray&>(out).CreateDNNLData(fwd.GetPd().dst_desc());
+  auto fwd_dst_desc = fwd.GetPd().dst_desc();
+  auto out_mem      = const_cast<NDArray&>(out).CreateDNNLData(&fwd_dst_desc);
 
   // mxnet will always use scale shift.
   // But if fix_gamma is true, then all scale elements will be set to 1.0f
@@ -387,10 +388,13 @@
   auto diff_mem = diff.GetDNNLData();
   // DNNL batchnorm should run on special layouts. If one of them isn't, we
   // should reorder them.
-  if (data.IsDefaultData())
-    data_mem = data.GetDNNLDataReorder(diff_mem->get_desc());
-  else if (diff.IsDefaultData())
-    diff_mem = diff.GetDNNLDataReorder(data_mem->get_desc());
+  if (data.IsDefaultData()) {
+    auto diff_desc = diff_mem->get_desc();
+    data_mem       = data.GetDNNLDataReorder(&diff_desc);
+  } else if (diff.IsDefaultData()) {
+    auto data_desc = data_mem->get_desc();
+    diff_mem       = diff.GetDNNLDataReorder(&data_desc);
+  }
   auto& bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
   auto gradi_mem =
       CreateDNNLMem(const_cast<NDArray&>(gradIn), bwd.pd.diff_src_desc(), req[batchnorm::kData]);
diff --git a/src/operator/nn/dnnl/dnnl_convolution.cc b/src/operator/nn/dnnl/dnnl_convolution.cc
index 7a21290..f28f273 100644
--- a/src/operator/nn/dnnl/dnnl_convolution.cc
+++ b/src/operator/nn/dnnl/dnnl_convolution.cc
@@ -478,7 +478,8 @@
   auto& weight = in_data[conv::kWeight];
   bool no_bias = param.conv_param.no_bias && !param.dnnl_param.with_bn;
 
-  auto data_mem = data.GetDNNLDataReorder(fwd->GetPd().src_desc());
+  auto fwd_src_desc = fwd->GetPd().src_desc();
+  auto data_mem     = data.GetDNNLDataReorder(&fwd_src_desc);
   const dnnl::memory* weight_mem;
   if (ctx.is_train) {
     // TODO(zhengda) kvstore doesn't handle DNNL correctly. Let's reorder it to the default format
@@ -493,10 +494,12 @@
     if (weight.IsDefaultData()) {
       // We also need to modify the layout on the original weight array. The data conversion happens
       // after the weight array is used.
-      weight.DNNLDataReorderAsync(fwd->GetPd().weights_desc());
+      auto fwd_weight_desc = fwd->GetPd().weights_desc();
+      weight.DNNLDataReorderAsync(&fwd_weight_desc);
       weight_mem = GetWeights(weight, fwd->GetPd().weights_desc(), param.conv_param.num_group);
     } else {
-      weight_mem = weight.GetDNNLDataReorder(fwd->GetPd().weights_desc());
+      auto fwd_weight_desc = fwd->GetPd().weights_desc();
+      weight_mem           = weight.GetDNNLDataReorder(&fwd_weight_desc);
     }
   }
   dnnl_output_t out_mem;
@@ -599,8 +602,9 @@
   const ConvolutionParam& param = full_param.conv_param;
 
   CHECK_NE(req[conv::kWeight], kWriteInplace) << "cannot write weight inplace";
-  DNNLConvBackward& convBwd = GetConvBwd(full_param, data, weight, bias, out_grad);
-  auto out_grad_mem         = out_grad.GetDNNLDataReorder(convBwd.GetDataPd().diff_dst_desc());
+  DNNLConvBackward& convBwd   = GetConvBwd(full_param, data, weight, bias, out_grad);
+  auto convBwd_data_diff_desc = convBwd.GetDataPd().diff_dst_desc();
+  auto out_grad_mem           = out_grad.GetDNNLDataReorder(&convBwd_data_diff_desc);
   if (req[conv::kData]) {
     auto weight_mem = GetWeights(weight, convBwd.GetDataPd().weights_desc(), param.num_group);
     auto in_grad_mem =
@@ -615,10 +619,13 @@
   auto req_weight = req.size() > conv::kWeight ? req.at(conv::kWeight) : kNullOp;
   auto req_bias   = req.size() > conv::kBias ? req.at(conv::kBias) : kNullOp;
   if (req_weight || req_bias) {
-    if (convBwd.GetDataPd().diff_dst_desc() != convBwd.GetWeightsPd().diff_dst_desc())
-      out_grad_mem = out_grad.GetDNNLDataReorder(convBwd.GetWeightsPd().diff_dst_desc());
-    auto data_mem       = data.GetDNNLDataReorder(convBwd.GetWeightsPd().src_desc());
-    auto in_grad_weight = CreateDNNLWeightGrad(
+    if (convBwd.GetDataPd().diff_dst_desc() != convBwd.GetWeightsPd().diff_dst_desc()) {
+      auto convBwd_weight_diff_desc = convBwd.GetWeightsPd().diff_dst_desc();
+      out_grad_mem                  = out_grad.GetDNNLDataReorder(&convBwd_weight_diff_desc);
+    }
+    auto convBwd_weight_src_desc = convBwd.GetWeightsPd().src_desc();
+    auto data_mem                = data.GetDNNLDataReorder(&convBwd_weight_src_desc);
+    auto in_grad_weight          = CreateDNNLWeightGrad(
         in_grad[conv::kWeight], convBwd.GetWeightsPd().diff_weights_desc(), req[conv::kWeight]);
 
     dnnl_args_map_t net_args = {{DNNL_ARG_DIFF_DST, *out_grad_mem},
diff --git a/src/operator/nn/dnnl/dnnl_copy.cc b/src/operator/nn/dnnl/dnnl_copy.cc
index 0fa9dc1..16cbabd 100644
--- a/src/operator/nn/dnnl/dnnl_copy.cc
+++ b/src/operator/nn/dnnl/dnnl_copy.cc
@@ -43,10 +43,11 @@
     TmpMemMgr::Get()->Init(ctx.requested[0]);
     // We should try and force the input memory has the same format
     // as the input output. If not, we'll have to reorder memory.
-    auto out_mem = out_data.GetDNNLData();
-    in_mem       = in_data.GetDNNLData(out_mem->get_desc());
+    auto out_mem      = out_data.GetDNNLData();
+    auto out_mem_desc = out_mem->get_desc();
+    in_mem            = in_data.GetDNNLData(&out_mem_desc);
     if (in_mem == nullptr)
-      in_mem = in_data.GetDNNLDataReorder(out_mem->get_desc());
+      in_mem = in_data.GetDNNLDataReorder(&out_mem_desc);
     DNNLSum(*out_mem, *in_mem, *out_mem);
   } else {
     const_cast<NDArray&>(out_data).CopyFrom(*in_mem);
diff --git a/src/operator/nn/dnnl/dnnl_deconvolution-inl.h b/src/operator/nn/dnnl/dnnl_deconvolution-inl.h
index 1078423..a1ac551 100644
--- a/src/operator/nn/dnnl/dnnl_deconvolution-inl.h
+++ b/src/operator/nn/dnnl/dnnl_deconvolution-inl.h
@@ -82,7 +82,8 @@
         temp.data_type(),
         static_cast<dnnl::memory::format_tag>(GetDefaultFormat(temp.data.ndims)));
   }
-  const_cast<NDArray&>(arr).UpdateDNNLMemDesc(IOLogicalSwapDesc(desc, num_group));
+  auto iOLogicalSwapDesc = IOLogicalSwapDesc(desc, num_group);
+  const_cast<NDArray&>(arr).UpdateDNNLMemDesc(&iOLogicalSwapDesc);
 }
 
 // Version of GetWeightsDesc for deconvolution (with swap)
@@ -149,7 +150,8 @@
 }
 
 inline const dnnl::memory* DNNLDeconvFwd::DataMem(const NDArray& data) const {
-  return data.GetDNNLDataReorder(fwd_pd->src_desc());
+  auto fwd_src_desc = fwd_pd->src_desc();
+  return data.GetDNNLDataReorder(&fwd_src_desc);
 }
 
 inline const dnnl::memory* DNNLDeconvFwd::WeightsMem(const uint32_t num_group,
@@ -275,7 +277,8 @@
 }
 
 inline const dnnl::memory* DNNLDeconvBwd::DataMem(const NDArray& data) const {
-  return data.GetDNNLDataReorder(bwd_weights_pd->src_desc());
+  auto bwd_weight_src_desc = bwd_weights_pd->src_desc();
+  return data.GetDNNLDataReorder(&bwd_weight_src_desc);
 }
 
 inline const dnnl::memory* DNNLDeconvBwd::WeightsMem(const uint32_t num_group,
@@ -284,14 +287,16 @@
 }
 
 inline const dnnl::memory* DNNLDeconvBwd::OutGradMem(const NDArray& out_grad) const {
-  return out_grad.GetDNNLDataReorder(bwd_data_pd->diff_dst_desc());
+  auto bwd_data_diff_desc = bwd_data_pd->diff_dst_desc();
+  return out_grad.GetDNNLDataReorder(&bwd_data_diff_desc);
 }
 
 inline const dnnl::memory* DNNLDeconvBwd::OutGradMem(const NDArray& out_grad,
                                                      const dnnl::memory* const out_grad_mem) const {
-  return (out_grad_mem && out_grad_mem->get_desc() == bwd_weights_pd->diff_dst_desc()) ?
+  auto bwd_weight_diff_desc = bwd_weights_pd->diff_dst_desc();
+  return (out_grad_mem && out_grad_mem->get_desc() == bwd_weight_diff_desc) ?
              out_grad_mem :
-             out_grad.GetDNNLDataReorder(bwd_weights_pd->diff_dst_desc());
+             out_grad.GetDNNLDataReorder(&bwd_weight_diff_desc);
 }
 
 inline dnnl_output_t DNNLDeconvBwd::DataGradMem(const OpReqType req,
@@ -308,7 +313,7 @@
   // swap, weights_md will have a default format
   const auto& weights_md = bwd_weights_pd->diff_weights_desc();
   if (req == OpReqType::kWriteTo && IsDefaultFormat(IOLogicalSwapDesc(weights_md, num_group))) {
-    return {OutDataOp::Noop, const_cast<NDArray&>(weights_grad).CreateDNNLData(weights_md)};
+    return {OutDataOp::Noop, const_cast<NDArray&>(weights_grad).CreateDNNLData(&weights_md)};
   }
   return CreateDNNLWeightGrad(weights_grad, weights_md, req);
 }
diff --git a/src/operator/nn/dnnl/dnnl_deconvolution.cc b/src/operator/nn/dnnl/dnnl_deconvolution.cc
index 9140d19..9487574 100644
--- a/src/operator/nn/dnnl/dnnl_deconvolution.cc
+++ b/src/operator/nn/dnnl/dnnl_deconvolution.cc
@@ -112,7 +112,8 @@
     if (weights.IsDefaultData()) {
       // We also need to modify the layout on the original weights array.
       // The data conversion happens after the weights array is used.
-      weights.DNNLDataReorderAsync(IOLogicalSwapDesc(fwd_pd->weights_desc(), num_group));
+      auto logical_swap_desc = IOLogicalSwapDesc(fwd_pd->weights_desc(), num_group);
+      weights.DNNLDataReorderAsync(&logical_swap_desc);
     } else {
       CHECK(weights.GetDNNLData()->get_desc() ==
             IOLogicalSwapDesc(fwd_pd->weights_desc(), num_group));
diff --git a/src/operator/nn/dnnl/dnnl_fully_connected.cc b/src/operator/nn/dnnl/dnnl_fully_connected.cc
index 6f04b19..bf58135 100644
--- a/src/operator/nn/dnnl/dnnl_fully_connected.cc
+++ b/src/operator/nn/dnnl/dnnl_fully_connected.cc
@@ -186,10 +186,10 @@
                               const std::vector<OpReqType>& req,
                               const std::vector<NDArray>& out_data) {
   TmpMemMgr::Get()->Init(ctx.requested[fullc::kTempSpace]);
-  NDArray weight = in_data[fullc::kWeight];
-  NDArray data   = in_data[fullc::kData];
-
-  auto data_mem = data.GetDNNLDataReorder(fwd->fwd_pd.src_desc());
+  NDArray weight    = in_data[fullc::kWeight];
+  NDArray data      = in_data[fullc::kData];
+  auto fwd_src_desc = fwd->fwd_pd.src_desc();
+  auto data_mem     = data.GetDNNLDataReorder(&fwd_src_desc);
   const dnnl::memory* weight_mem;
   if (ctx.is_train) {
     if (weight.IsDNNLData()) {
@@ -199,7 +199,8 @@
   } else {
     weight_mem = weight.GetDNNLData();
     if (weight_mem->get_desc() != fwd->fwd_pd.weights_desc()) {
-      weight.DNNLDataReorderAsync(fwd->fwd_pd.weights_desc());
+      auto fwd_weight_desc = fwd->fwd_pd.weights_desc();
+      weight.DNNLDataReorderAsync(&fwd_weight_desc);
       weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1);
     }
   }
@@ -212,7 +213,8 @@
       {DNNL_ARG_DST, *out_mem.second},
   };
   if (!full_param.default_param.no_bias) {
-    auto bias_mem       = in_data[fullc::kBias].GetDNNLDataReorder(fwd->fwd_pd.bias_desc());
+    auto fwd_bias_desc  = fwd->fwd_pd.bias_desc();
+    auto bias_mem       = in_data[fullc::kBias].GetDNNLDataReorder(&fwd_bias_desc);
     args[DNNL_ARG_BIAS] = *bias_mem;
   }
   DNNLStream::Get()->RegisterPrimArgs(fwd->GetFwd(), args);
@@ -286,9 +288,11 @@
   if (req[fullc::kWeight]) {
     dnnl::inner_product_backward_weights::primitive_desc ipBwdWeights_pd = GetFCBwdWeights(
         data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias], out_grad, fwd_pd);
-    auto out_grad_mem   = out_grad.GetDNNLDataReorder(ipBwdWeights_pd.diff_dst_desc());
-    auto data_mem       = data.GetDNNLDataReorder(ipBwdWeights_pd.src_desc());
-    auto in_grad_weight = CreateDNNLWeightGrad(
+    auto ipBwdWeights_diff_dst_desc = ipBwdWeights_pd.diff_dst_desc();
+    auto ipBwdWeights_src_desc      = ipBwdWeights_pd.src_desc();
+    auto out_grad_mem               = out_grad.GetDNNLDataReorder(&ipBwdWeights_diff_dst_desc);
+    auto data_mem                   = data.GetDNNLDataReorder(&ipBwdWeights_src_desc);
+    auto in_grad_weight             = CreateDNNLWeightGrad(
         in_grad[fullc::kWeight], ipBwdWeights_pd.diff_weights_desc(), req[fullc::kWeight]);
     dnnl_args_map_t args = {
         {DNNL_ARG_DIFF_DST, *out_grad_mem},
@@ -312,8 +316,10 @@
   if (req[fullc::kData]) {
     dnnl::inner_product_backward_data::primitive_desc ipBwdData_pd =
         GetFCBwdData(data, weight, out_grad, fwd_pd);
-    auto out_grad_mem = out_grad.GetDNNLDataReorder(ipBwdData_pd.diff_dst_desc());
-    auto weight_mem   = weight.GetDNNLDataReorder(ipBwdData_pd.weights_desc());
+    auto ipBwdData_diff_dst_desc = ipBwdData_pd.diff_dst_desc();
+    auto ipBwdData_weight_desc   = ipBwdData_pd.weights_desc();
+    auto out_grad_mem            = out_grad.GetDNNLDataReorder(&ipBwdData_diff_dst_desc);
+    auto weight_mem              = weight.GetDNNLDataReorder(&ipBwdData_weight_desc);
     auto in_grad_mem =
         CreateDNNLMem(in_grad[fullc::kData], ipBwdData_pd.diff_src_desc(), req[fullc::kData]);
     dnnl_args_map_t args = {{DNNL_ARG_DIFF_DST, *out_grad_mem},
diff --git a/src/operator/nn/dnnl/dnnl_layer_norm.cc b/src/operator/nn/dnnl/dnnl_layer_norm.cc
index 4108a62..65eb0a3 100644
--- a/src/operator/nn/dnnl/dnnl_layer_norm.cc
+++ b/src/operator/nn/dnnl/dnnl_layer_norm.cc
@@ -134,10 +134,11 @@
                                const std::vector<NDArray>& outputs) const {
   auto mean_var_md = GetMeanVarDesc(get_dnnl_type(outputs[layernorm::kMean].dtype()),
                                     outputs[layernorm::kMean].shape());
-  auto mean_mem    = dnnl_output_t(
-      OutDataOp::Noop, const_cast<NDArray&>(outputs[layernorm::kMean]).CreateDNNLData(mean_var_md));
+  auto mean_mem =
+      dnnl_output_t(OutDataOp::Noop,
+                    const_cast<NDArray&>(outputs[layernorm::kMean]).CreateDNNLData(&mean_var_md));
   auto variance_mem = dnnl_output_t(
-      OutDataOp::Noop, const_cast<NDArray&>(outputs[layernorm::kStd]).CreateDNNLData(mean_var_md));
+      OutDataOp::Noop, const_cast<NDArray&>(outputs[layernorm::kStd]).CreateDNNLData(&mean_var_md));
 
   auto output_mem      = CreateDNNLMem(outputs[layernorm::kOut], fwd_pd->dst_desc(), req);
   auto scale_shift_mem = GetScaleShiftMem(inputs[layernorm::kGamma], inputs[layernorm::kBeta]);
@@ -183,7 +184,8 @@
                                const std::vector<OpReqType>& req) const {
   auto scale_shift_mem =
       GetScaleShiftMem(inputs[layernorm::kBwdGamma], inputs[layernorm::kBwdBeta]);
-  auto diff_weights_ndarray = NDArray(scale_shift_mem.get_desc());
+  auto scale_shift_mem_desc = scale_shift_mem.get_desc();
+  auto diff_weights_ndarray = NDArray(&scale_shift_mem_desc);
   const auto bytes          = inputs[layernorm::kBwdGamma].shape()[0] *
                      mshadow::mshadow_sizeof(inputs[layernorm::kBwdGamma].dtype());
   const auto diff_weights_ndaray_data_ptr_plus_bytes = reinterpret_cast<void*>(
diff --git a/src/operator/nn/dnnl/dnnl_log_softmax.cc b/src/operator/nn/dnnl/dnnl_log_softmax.cc
index 9243440..be1abdb 100644
--- a/src/operator/nn/dnnl/dnnl_log_softmax.cc
+++ b/src/operator/nn/dnnl/dnnl_log_softmax.cc
@@ -127,9 +127,10 @@
   int axis                  = CheckAxis(param.axis, in_data.shape().ndim());
   auto fwd                  = GetLogSoftmaxFwd(param, axis, ctx.is_train, in_data, out_data);
 
-  auto in_mem        = in_data.GetDNNLData();
-  auto out_mem       = out_data.GetDNNLData(fwd.pd.dst_desc());
-  DNNLStream* stream = DNNLStream::Get();
+  auto in_mem          = in_data.GetDNNLData();
+  auto fwd_pd_dst_desc = fwd.pd.dst_desc();
+  auto out_mem         = out_data.GetDNNLData(&fwd_pd_dst_desc);
+  DNNLStream* stream   = DNNLStream::Get();
   stream->RegisterPrimArgs(fwd.GetFwd(), {{DNNL_ARG_SRC, *in_mem}, {DNNL_ARG_DST, *out_mem}});
   stream->Submit();
 }
diff --git a/src/operator/nn/dnnl/dnnl_masked_softmax.cc b/src/operator/nn/dnnl/dnnl_masked_softmax.cc
index 789a6bf..a2a9c58 100644
--- a/src/operator/nn/dnnl/dnnl_masked_softmax.cc
+++ b/src/operator/nn/dnnl/dnnl_masked_softmax.cc
@@ -169,10 +169,11 @@
   p.axis        = param.axis;
   p.temperature = param.temperature;
 
-  auto softmax_tensors = DNNLSoftmaxFwd::Tensors(output, output);
-  auto softmax_op      = DNNLSoftmaxFwd::GetCached(p, softmax_tensors, is_train);
-  auto softmax_out_mem = output.GetDNNLData(softmax_op.softmax_pd->dst_desc());
-  const auto input_mem = input.GetDNNLData();
+  auto softmax_tensors     = DNNLSoftmaxFwd::Tensors(output, output);
+  auto softmax_op          = DNNLSoftmaxFwd::GetCached(p, softmax_tensors, is_train);
+  auto softmax_op_dst_desc = softmax_op.softmax_pd->dst_desc();
+  auto softmax_out_mem     = output.GetDNNLData(&softmax_op_dst_desc);
+  const auto input_mem     = input.GetDNNLData();
 
   // 1. C) out = input * out
   stream->RegisterPrimArgs(this->primitives->mask_input,
diff --git a/src/operator/nn/dnnl/dnnl_pooling.cc b/src/operator/nn/dnnl/dnnl_pooling.cc
index 85e72bb..c2dfe36 100644
--- a/src/operator/nn/dnnl/dnnl_pooling.cc
+++ b/src/operator/nn/dnnl/dnnl_pooling.cc
@@ -431,10 +431,11 @@
 
   TmpMemMgr::Get()->Init(ctx.requested[0]);
 
-  auto& bwd         = GetPoolingBwd(param, *in_data, in_grad, out_grad, param.IsAdaptivePooling());
-  auto diff_dst_mem = out_grad.GetDNNLDataReorder(bwd.pd.diff_dst_desc());
-  auto diff_src_mem = CreateDNNLMem(in_grad, bwd.pd.diff_src_desc(), req[0]);
-  dnnl_args_map_t args = {
+  auto& bwd = GetPoolingBwd(param, *in_data, in_grad, out_grad, param.IsAdaptivePooling());
+  auto bwd_diff_dst_desc = bwd.pd.diff_dst_desc();
+  auto diff_dst_mem      = out_grad.GetDNNLDataReorder(&bwd_diff_dst_desc);
+  auto diff_src_mem      = CreateDNNLMem(in_grad, bwd.pd.diff_src_desc(), req[0]);
+  dnnl_args_map_t args   = {
       {DNNL_ARG_DIFF_DST, *diff_dst_mem},
       {DNNL_ARG_DIFF_SRC, *diff_src_mem.second},
   };
diff --git a/src/operator/nn/dnnl/dnnl_reduce.cc b/src/operator/nn/dnnl/dnnl_reduce.cc
index f486c2f..a9be0af 100644
--- a/src/operator/nn/dnnl/dnnl_reduce.cc
+++ b/src/operator/nn/dnnl/dnnl_reduce.cc
@@ -225,7 +225,8 @@
     auto out_mem = dnnl::memory(reduce_pd->dst_desc(), engine, tensors.out.data().dptr<float>());
     stream->RegisterPrimArgs(*reduce_fwd, {{DNNL_ARG_SRC, *input_mem}, {DNNL_ARG_DST, out_mem}});
   } else {
-    auto out_mem = tensors.out.GetDNNLData(reduce_pd->dst_desc());
+    auto desc    = reduce_pd->dst_desc();
+    auto out_mem = tensors.out.GetDNNLData(&desc);
     stream->RegisterPrimArgs(*reduce_fwd, {{DNNL_ARG_SRC, *input_mem}, {DNNL_ARG_DST, *out_mem}});
   }
   stream->Submit();
diff --git a/src/operator/nn/dnnl/dnnl_softmax-inl.h b/src/operator/nn/dnnl/dnnl_softmax-inl.h
index c681f33..42558c6 100644
--- a/src/operator/nn/dnnl/dnnl_softmax-inl.h
+++ b/src/operator/nn/dnnl/dnnl_softmax-inl.h
@@ -81,7 +81,6 @@
   std::shared_ptr<linear_t> temperature_fwd;
 };
 
-
 class DNNLSoftmaxBwd {
  public:
   struct Tensors {
@@ -107,7 +106,6 @@
   std::shared_ptr<linear_t> temperature_fwd;
 };
 
-
 }  // namespace op
 }  // namespace mxnet
 #endif
diff --git a/src/operator/nn/dnnl/dnnl_softmax.cc b/src/operator/nn/dnnl/dnnl_softmax.cc
index 93fe557..73321a1 100644
--- a/src/operator/nn/dnnl/dnnl_softmax.cc
+++ b/src/operator/nn/dnnl/dnnl_softmax.cc
@@ -139,8 +139,9 @@
 void DNNLSoftmaxFwd::Execute(const Tensors& tensors) const {
   DNNLStream* stream = DNNLStream::Get();
 
-  auto original_input_mem = tensors.data.GetDNNLData();
-  const auto out_mem      = tensors.out.GetDNNLData(softmax_pd->dst_desc());
+  auto original_input_mem  = tensors.data.GetDNNLData();
+  auto softmax_pd_dst_desc = softmax_pd->dst_desc();
+  const auto out_mem       = tensors.out.GetDNNLData(&softmax_pd_dst_desc);
 
   dnnl::memory* softmax_input_mem;
   if (temperature_pd) {
diff --git a/src/operator/nn/dnnl/dnnl_where.cc b/src/operator/nn/dnnl/dnnl_where.cc
index c2335b9..7d5ca4a 100644
--- a/src/operator/nn/dnnl/dnnl_where.cc
+++ b/src/operator/nn/dnnl/dnnl_where.cc
@@ -171,9 +171,13 @@
   const auto& cpu_engine = CpuEngine::Get()->get_engine();
   const auto& cpu_stream = ctx.get_stream<cpu>();
 
-  const auto& cnd_tensor = tensors.condition.GetDNNLDataReorder(binary_eq_zero_pd.src0_desc());
-  const auto& lhs_tensor = tensors.left.GetDNNLDataReorder(binary_mul_l_pd.src0_desc());
-  const auto& rhs_tensor = tensors.right.GetDNNLDataReorder(binary_mul_r_pd.src0_desc());
+  auto binary_eq_zero_pd_desc = binary_eq_zero_pd.src0_desc();
+  auto binary_mul_l_pd_desc   = binary_mul_l_pd.src0_desc();
+  auto binary_mul_r_pd_desc   = binary_mul_r_pd.src0_desc();
+
+  const auto& cnd_tensor = tensors.condition.GetDNNLDataReorder(&binary_eq_zero_pd_desc);
+  const auto& lhs_tensor = tensors.left.GetDNNLDataReorder(&binary_mul_l_pd_desc);
+  const auto& rhs_tensor = tensors.right.GetDNNLDataReorder(&binary_mul_r_pd_desc);
 
   mxnet::dnnl_output_t out_mem = CreateDNNLMem(tensors.output, binary_sum_pd.dst_desc(), req[0]);
 
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_batch_norm.cc b/src/operator/quantization/dnnl/dnnl_quantized_batch_norm.cc
index 3f13775..ba3b2f8 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_batch_norm.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_batch_norm.cc
@@ -114,7 +114,8 @@
   }
 
   const NDArray& out = outputs[batchnorm::kOut];
-  auto out_mem       = const_cast<NDArray&>(out).CreateDNNLData(fwd.GetPd().dst_desc());
+  auto fwd_dst_desc  = fwd.GetPd().dst_desc();
+  auto out_mem       = const_cast<NDArray&>(out).CreateDNNLData(&fwd_dst_desc);
   dnnl_args_map_t net_args;
   net_args[DNNL_ARG_SRC]         = *data_mem;
   net_args[DNNL_ARG_SCALE_SHIFT] = weight_mem;
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_conv.cc b/src/operator/quantization/dnnl/dnnl_quantized_conv.cc
index 158d0ea..85b915e 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_conv.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_conv.cc
@@ -45,13 +45,14 @@
   DNNLConvFullParam full_param;
   full_param.conv_param = param;
   full_param.dnnl_param.Init(std::unordered_map<std::string, std::string>());
-  auto& fwd     = GetConvFwd(full_param,
+  auto& fwd         = GetConvFwd(full_param,
                          ctx.is_train,
                          in_data[conv::kData],
                          in_data[conv::kWeight],
                          param.no_bias ? nullptr : &in_data[conv::kBias],
                          out_data[conv::kOut]);
-  auto data_mem = in_data[conv::kData].GetDNNLDataReorder(fwd.GetPd().src_desc());
+  auto fwd_src_desc = fwd.GetPd().src_desc();
+  auto data_mem     = in_data[conv::kData].GetDNNLDataReorder(&fwd_src_desc);
   const dnnl::memory* weight_mem;
   // For inference, we want to reorder the weight array so we don't need to
   // reorder data every time.
@@ -59,15 +60,17 @@
     // We also need to modify the layout on the original weight array.
     // Don't switch below sequence because naive engine will executes
     // pushAsync synchronously.
-    weight.DNNLDataReorderAsync(fwd.GetPd().weights_desc());
-    weight_mem = GetWeights(weight, fwd.GetPd().weights_desc(), param.num_group);
+    auto fwd_weight_desc = fwd.GetPd().weights_desc();
+    weight.DNNLDataReorderAsync(&fwd_weight_desc);
+    weight_mem = GetWeights(weight, fwd_weight_desc, param.num_group);
   } else {
     weight_mem = weight.GetDNNLData();
   }
   auto out_mem = CreateDNNLMem(out_data[conv::kOut], fwd.GetPd().dst_desc(), req[conv::kOut]);
   dnnl_args_map_t net_args;
   if (!param.no_bias) {
-    const dnnl::memory* bias_mem = in_data[conv::kBias].GetDNNLDataReorder(fwd.GetPd().bias_desc());
+    auto fwd_bias_desc           = fwd.GetPd().bias_desc();
+    const dnnl::memory* bias_mem = in_data[conv::kBias].GetDNNLDataReorder(&fwd_bias_desc);
     net_args.insert({DNNL_ARG_BIAS, *bias_mem});
   }
   net_args.insert({DNNL_ARG_SRC, *data_mem});
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_fully_connected.cc b/src/operator/quantization/dnnl/dnnl_quantized_fully_connected.cc
index 7746129..ca75f54 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_fully_connected.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_fully_connected.cc
@@ -90,15 +90,17 @@
   auto& fwd =
       GetFCFwd(param, is_train, data, weight, param.no_bias ? nullptr : &quantized_bias, out_md);
 
-  auto data_mem                  = in_data[fullc::kData].GetDNNLDataReorder(fwd.fwd_pd.src_desc());
+  auto fwd_src_desc              = fwd.fwd_pd.src_desc();
+  auto data_mem                  = in_data[fullc::kData].GetDNNLDataReorder(&fwd_src_desc);
   const dnnl::memory* weight_mem = nullptr;
 
   if (weight.IsDefaultData()) {
     // We also need to modify the layout on the original weight array.
     // Don't switch below sequence because naive engine will executes
     // pushAsync synchronously.
-    weight.DNNLDataReorderAsync(fwd.fwd_pd.weights_desc());
-    weight_mem = GetWeights(weight, fwd.fwd_pd.weights_desc(), 1);
+    auto fwd_weight_desc = fwd.fwd_pd.weights_desc();
+    weight.DNNLDataReorderAsync(&fwd_weight_desc);
+    weight_mem = GetWeights(weight, fwd_weight_desc, 1);
   } else {
     weight_mem = weight.GetDNNLData();
     CHECK(weight_mem->get_desc() == fwd.fwd_pd.weights_desc());
@@ -113,7 +115,8 @@
 
   const dnnl::memory* bias_mem = nullptr;
   if (!param.no_bias) {
-    bias_mem            = quantized_bias.GetDNNLDataReorder(fwd.fwd_pd.bias_desc());
+    auto fwd_bias_desc  = fwd.fwd_pd.bias_desc();
+    bias_mem            = quantized_bias.GetDNNLDataReorder(&fwd_bias_desc);
     args[DNNL_ARG_BIAS] = *bias_mem;
   }
 
diff --git a/src/operator/subgraph/dnnl/dnnl_common.h b/src/operator/subgraph/dnnl/dnnl_common.h
index 68f10c0..3db25aa 100644
--- a/src/operator/subgraph/dnnl/dnnl_common.h
+++ b/src/operator/subgraph/dnnl/dnnl_common.h
@@ -103,7 +103,7 @@
                                           const std::vector<float>& weight_scales,
                                           const bool submit = true) {
   DNNLStream* stream             = DNNLStream::Get();
-  const auto new_weight          = NDArray(weight_md);
+  const auto new_weight          = NDArray(&weight_md);
   const auto conv_weights_memory = new_weight.GetDNNLData();
   dnnl::primitive_attr weight_attr;
   if (weight_scales.size()) {
@@ -124,7 +124,7 @@
     for (size_t c = 0; c < weight_scales.size(); ++c) {
       bias_scales[c] = weight_scales[c] * data_scale;
     }
-    new_bias                    = NDArray(*bias_md);
+    new_bias                    = NDArray(bias_md);
     const auto conv_bias_memory = new_bias.GetDNNLData();
     const int bias_mask         = (bias_scales.size()) == 1 ? 0 : 1;
     dnnl::primitive_attr bias_attr;
diff --git a/src/operator/subgraph/dnnl/dnnl_conv.cc b/src/operator/subgraph/dnnl/dnnl_conv.cc
index 2712ee9..2627460 100644
--- a/src/operator/subgraph/dnnl/dnnl_conv.cc
+++ b/src/operator/subgraph/dnnl/dnnl_conv.cc
@@ -358,7 +358,7 @@
     const auto& out_mem_desc = output_mem->get_desc();
     const auto& dst_mem_desc = fwd_->GetPd().dst_desc();
     if (out_mem_desc != dst_mem_desc) {
-      auto tmp_out_mem       = output.GetDNNLDataReorder(fwd_->GetPd().dst_desc());
+      auto tmp_out_mem       = output.GetDNNLDataReorder(&dst_mem_desc);
       auto data_md           = dst_mem_desc;
       data_md.data.data_type = static_cast<dnnl_data_type_t>(out_mem_desc.data.data_type);
       dnnl_mem_ptr new_out_mem(
@@ -370,10 +370,12 @@
   }
 
   if (dnnl_param.quantized) {
-    auto data_mem       = data.GetDNNLDataReorder(fwd_->GetPd().src_desc());
-    dnnl::memory* mem   = output.CreateDNNLData(fwd_->GetPd().dst_desc());
-    args_[DNNL_ARG_SRC] = *data_mem;
-    args_[DNNL_ARG_DST] = *mem;
+    auto fwd_src_desc    = fwd_->GetPd().src_desc();
+    auto data_mem        = data.GetDNNLDataReorder(&fwd_src_desc);
+    auto fwd_pd_dst_desc = fwd_->GetPd().dst_desc();
+    dnnl::memory* mem    = output.CreateDNNLData(&fwd_pd_dst_desc);
+    args_[DNNL_ARG_SRC]  = *data_mem;
+    args_[DNNL_ARG_DST]  = *mem;
     DNNLStream::Get()->RegisterPrimArgs(fwd_->GetFwd(), args_);
     DNNLStream::Get()->Submit();
   } else {
@@ -391,8 +393,9 @@
     *outputs[kMax].data().dptr<float>() = cached_output_max_;
   }
   if (dnnl_param.with_sum) {
-    auto out = const_cast<NDArray&>(outputs[kOut]);
-    out.UpdateDNNLMemDesc(fwd_->GetPd().dst_desc());
+    auto out          = const_cast<NDArray&>(outputs[kOut]);
+    auto fwd_dst_desc = fwd_->GetPd().dst_desc();
+    out.UpdateDNNLMemDesc(&fwd_dst_desc);
   }
 }
 
diff --git a/src/operator/subgraph/dnnl/dnnl_fc.cc b/src/operator/subgraph/dnnl/dnnl_fc.cc
index f9d0d57..a5c199f 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc.cc
+++ b/src/operator/subgraph/dnnl/dnnl_fc.cc
@@ -417,7 +417,7 @@
       const auto def_weight_mem = static_cast<const dnnl::memory*>(weight.GetDNNLData());
       if (def_weight_mem->get_desc() != fwd_->fwd_pd.weights_desc()) {
         auto weight_desc       = fwd_->fwd_pd.weights_desc();
-        cached_weight_         = NDArray(weight_desc);
+        cached_weight_         = NDArray(&weight_desc);
         auto cached_weight_mem = static_cast<const dnnl::memory*>(cached_weight_.GetDNNLData());
         std::unordered_map<int, dnnl::memory> args(
             {{DNNL_ARG_FROM, *def_weight_mem}, {DNNL_ARG_TO, *cached_weight_mem}});
@@ -442,7 +442,7 @@
     const auto& out_mem_desc = output_mem->get_desc();
     auto dst_mem_desc        = fwd_->fwd_pd.dst_desc();
     if (out_mem_desc != dst_mem_desc) {
-      auto tmp_out_mem            = output.GetDNNLDataReorder(dst_mem_desc);
+      auto tmp_out_mem            = output.GetDNNLDataReorder(&dst_mem_desc);
       dst_mem_desc.data.data_type = out_mem_desc.data.data_type;
       dnnl_mem_ptr new_out_mem(new dnnl::memory(
           dst_mem_desc, CpuEngine::Get()->get_engine(), output_mem->get_data_handle()));
diff --git a/tests/cpp/include/test_dnnl.h b/tests/cpp/include/test_dnnl.h
index 7172b0b..166074e 100644
--- a/tests/cpp/include/test_dnnl.h
+++ b/tests/cpp/include/test_dnnl.h
@@ -86,7 +86,7 @@
                                  bool is_rand = false,
                                  int max      = 50) {
   InitDefaultArray(arr, is_rand, max);
-  arr->DNNLDataReorderAsync(desc);
+  arr->DNNLDataReorderAsync(&desc);
   arr->WaitToRead();
 }
 
diff --git a/tests/cpp/operator/dnnl_test.cc b/tests/cpp/operator/dnnl_test.cc
index 510d368..2874df1 100644
--- a/tests/cpp/operator/dnnl_test.cc
+++ b/tests/cpp/operator/dnnl_test.cc
@@ -139,7 +139,7 @@
     InitDefaultArray(&arr);
     for (auto md : mds) {
       if (s.Size() == md.get_size() / sizeof(mshadow::default_real_t)) {
-        const dnnl::memory* mem = arr.GetDNNLDataReorder(md);
+        const dnnl::memory* mem = arr.GetDNNLDataReorder(&md);
         printf("reorder from (");
         for (size_t i = 0; i < s.ndim(); i++)
           printf("%ld, ", s[i]);
@@ -171,7 +171,7 @@
         InitDNNLArray(&arr, md);
         for (auto to_md : mds) {
           if (to_md.get_size() / sizeof(mshadow::default_real_t) == s.Size()) {
-            const dnnl::memory* mem = arr.GetDNNLDataReorder(to_md);
+            const dnnl::memory* mem = arr.GetDNNLDataReorder(&to_md);
             printf("reorder from (");
             for (size_t i = 0; i < s.ndim(); i++)
               printf("%ld, ", s[i]);