[Master] Decouple OneDNN data structures in MXNet C++ API (#20624)
* NDArray file has been modified; there are a few changes:
1. OneDNN header was moved into the *.cc file
2. OneDNN objects are created on the fly: static_cast is needed
* Removed static_cast<*>
* Roll back ndarray file
* Roll Back - fwd namespace declaration
* Clang-format: auto-format
* Sanity fix
* Conv sanity-fix
* Fix typo
* Fix ternary
* Review
* Review: test file
* Static_cast: removed redundant cast
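
The pattern applied throughout the diff below is sketched here with a minimal, hypothetical `Example` class (not actual MXNet code): the public header only forward-declares the oneDNN type and exposes a `const void*` parameter, while the implementation file includes `dnnl.hpp` and recovers the descriptor with a `static_cast`. Callers materialize the descriptor in a local lvalue and pass its address.

```cpp
// example.h -- public header: no oneDNN headers included
namespace dnnl {
struct memory;  // forward declaration only; dnnl.hpp stays out of the header
}  // namespace dnnl

class Example {
 public:
  // `md` is expected to point at a dnnl::memory::desc owned by the caller.
  const dnnl::memory* GetData(const void* md) const;
};

// example.cc -- implementation: oneDNN types are fully visible here
#include <dnnl.hpp>

const dnnl::memory* Example::GetData(const void* mem_desc) const {
  // Recover the concrete descriptor on the fly.
  const dnnl::memory::desc md = *static_cast<const dnnl::memory::desc*>(mem_desc);
  // ... use `md` exactly as the old reference-taking overload did ...
  return nullptr;  // placeholder for the sketch
}

// Caller side: bind the descriptor to a named variable, then pass its address:
//   auto dst_desc = pd.dst_desc();
//   const dnnl::memory* mem = arr.GetData(&dst_desc);
```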
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index ec48984..bed166a 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -36,9 +36,6 @@
#include <memory>
#include <string>
#include <vector>
-#if MXNET_USE_ONEDNN == 1
-#include <dnnl.hpp>
-#endif
#include "./base.h"
#include "./engine.h"
#include "./storage.h"
@@ -47,6 +44,10 @@
#error "cxx11 was required for ndarray module"
#endif
+namespace dnnl {
+struct memory;
+} // namespace dnnl
+
namespace mxnet {
// enum for storage types
namespace csr {
@@ -743,7 +744,7 @@
* Create NDArray from dnnl memory descriptor.
* mem_pd The dnnl memory descriptor to be created.
*/
- explicit NDArray(const dnnl::memory::desc& md);
+ explicit NDArray(const void* md);
/*
* Test if the data is stored in one of special DNNL formats.
*/
@@ -771,13 +772,13 @@
* This function returns dnnl::memory with the given primitive_desc
* as long as the array size meets the required size in the given primitive_desc.
*/
- const dnnl::memory* GetDNNLData(const dnnl::memory::desc& md) const;
+ const dnnl::memory* GetDNNLData(const void* md) const;
/*
* This function returns dnnl::memory with the given primitive_desc.
* The returned dnnl::memory will have the same physical layout as
* the given primitive_desc.
*/
- const dnnl::memory* GetDNNLDataReorder(const dnnl::memory::desc& md) const;
+ const dnnl::memory* GetDNNLDataReorder(const void* md) const;
/*
* This function copies data from dnnl memory.
@@ -787,7 +788,7 @@
* This function allocates memory for array and creates dnnl memory
* with the specified format.
*/
- dnnl::memory* CreateDNNLData(const dnnl::memory::desc& md);
+ dnnl::memory* CreateDNNLData(const void* md);
/*
* These are the async version of the methods above.
@@ -795,7 +796,7 @@
* the array are complete.
*/
void Reorder2DefaultAsync() const;
- void DNNLDataReorderAsync(const dnnl::memory::desc& md) const;
+ void DNNLDataReorderAsync(const void* md) const;
/*
* This creates a new NDArray with the reordered data.
@@ -826,7 +827,7 @@
/*!
* \ Fix dnnl memory descriptor mismatch from NDArray.
*/
- void UpdateDNNLMemDesc(const dnnl::memory::desc& desc);
+ void UpdateDNNLMemDesc(const void* desc);
#endif
/*!
@@ -1111,7 +1112,7 @@
// save the result in shandle.
void Reorder2Default();
// Reroder data to a specified layout.
- void DNNLDataReorder(const dnnl::memory::desc& md);
+ void DNNLDataReorder(const void* md);
bool IsDNNL() const;
bool IsDefault() const;
#endif
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 66870da..605e705 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -37,7 +37,9 @@
#include "../operator/tensor/matrix_op-inl.h"
#include "../profiler/storage_profiler.h"
#include "./ndarray_function.h"
-
+#if MXNET_USE_ONEDNN == 1
+#include <dnnl.hpp>
+#endif
#if MXNET_USE_OPENCV
#include <opencv2/opencv.hpp>
#endif // MXNET_USE_OPENCV
@@ -211,11 +213,11 @@
#if MXNET_USE_ONEDNN == 1
-NDArray::NDArray(const dnnl::memory::desc& md)
- : storage_type_(kDefaultStorage), autograd_entry_(nullptr) {
- shape_ = mxnet::TShape(md.data.dims, md.data.dims + md.data.ndims);
- dtype_ = get_mxnet_type(md.data.data_type);
- ptr_ = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
+NDArray::NDArray(const void* md_desc) : storage_type_(kDefaultStorage), autograd_entry_(nullptr) {
+ dnnl::memory::desc md = *static_cast<const dnnl::memory::desc*>(md_desc);
+ shape_ = mxnet::TShape(md.data.dims, md.data.dims + md.data.ndims);
+ dtype_ = get_mxnet_type(md.data.data_type);
+ ptr_ = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
ptr_->CheckAndAlloc(md.get_size());
ptr_->dnnl_mem_ = std::make_shared<DNNLMemory>(md, ptr_->shandle.dptr);
}
@@ -557,7 +559,8 @@
dnnl_mem_ = nullptr;
}
-void NDArray::Chunk::DNNLDataReorder(const dnnl::memory::desc& md) {
+void NDArray::Chunk::DNNLDataReorder(const void* mem_desc) {
+ const dnnl::memory::desc md = *static_cast<const dnnl::memory::desc*>(mem_desc);
// If the memory already uses the specified layout, don't do anything.
if (dnnl_mem_ != nullptr && dnnl_mem_->SameFormat(md))
return;
@@ -623,7 +626,8 @@
dnnl_mem_.reset(new DNNLMemory(data_md, shandle.dptr));
}
-const dnnl::memory* NDArray::GetDNNLData(const dnnl::memory::desc& desc) const {
+const dnnl::memory* NDArray::GetDNNLData(const void* mem_desc) const {
+ const dnnl::memory::desc desc = *static_cast<const dnnl::memory::desc*>(mem_desc);
if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
LOG(FATAL) << "The size of NDArray doesn't match the requested oneDNN memory desc";
return nullptr;
@@ -639,7 +643,8 @@
}
}
-const dnnl::memory* NDArray::GetDNNLDataReorder(const dnnl::memory::desc& new_desc) const {
+const dnnl::memory* NDArray::GetDNNLDataReorder(const void* mem_desc) const {
+ dnnl::memory::desc new_desc = *static_cast<const dnnl::memory::desc*>(mem_desc);
CHECK(storage_type() == kDefaultStorage);
const dnnl::memory* mem = GetDNNLData();
@@ -774,7 +779,8 @@
return ret;
}
-void NDArray::DNNLDataReorderAsync(const dnnl::memory::desc& desc) const {
+void NDArray::DNNLDataReorderAsync(const void* mem_desc) const {
+ dnnl::memory::desc desc = *static_cast<const dnnl::memory::desc*>(mem_desc);
std::vector<Engine::VarHandle> const_vars;
std::vector<Engine::VarHandle> mutable_vars(1, this->var());
NDArray tmp = *this;
@@ -787,7 +793,7 @@
// MXNet will try to reuse NDArray from memory planning, so we need to ensure
// the NDArray is still holding the original trunk data.
if (tmp.version() == version) {
- tmp.ptr_->DNNLDataReorder(desc);
+ tmp.ptr_->DNNLDataReorder(&desc);
}
on_complete();
},
@@ -860,7 +866,8 @@
DNNLMemoryCopy(mem, this_mem);
}
-dnnl::memory* NDArray::CreateDNNLData(const dnnl::memory::desc& desc) {
+dnnl::memory* NDArray::CreateDNNLData(const void* mem_desc) {
+ dnnl::memory::desc desc = *static_cast<const dnnl::memory::desc*>(mem_desc);
if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
LOG(FATAL) << "The size of NDArray doesn't match the requested oneDNN memory desc. "
<< "oneDNN memory requests for " << desc.get_size() << " bytes, but got "
@@ -906,7 +913,8 @@
return ptr_->dnnl_mem_->GetRaw();
}
-void NDArray::UpdateDNNLMemDesc(const dnnl::memory::desc& desc) {
+void NDArray::UpdateDNNLMemDesc(const void* mem_desc) {
+ dnnl::memory::desc desc = *static_cast<const dnnl::memory::desc*>(mem_desc);
auto new_desc = desc;
auto this_dtype = get_dnnl_type(dtype());
new_desc.data.data_type = static_cast<dnnl_data_type_t>(this_dtype);
diff --git a/src/operator/nn/dnnl/dnnl_act.cc b/src/operator/nn/dnnl/dnnl_act.cc
index 2cc8a34..4b51e45 100644
--- a/src/operator/nn/dnnl/dnnl_act.cc
+++ b/src/operator/nn/dnnl/dnnl_act.cc
@@ -255,8 +255,9 @@
auto input_mem = in_buffer.GetDNNLData();
// We need to make sure the two inputs to eltwise_backward has the same memory
// descriptor. Otherwise, the perf will suffer.
- if (input_mem->get_desc() != diff_dst_memory->get_desc()) {
- input_mem = in_buffer.GetDNNLDataReorder(diff_dst_memory->get_desc());
+ auto diff_dst_desc = diff_dst_memory->get_desc();
+ if (input_mem->get_desc() != diff_dst_desc) {
+ input_mem = in_buffer.GetDNNLDataReorder(&diff_dst_desc);
}
DNNLActBackward& bwd = GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem);
@@ -264,7 +265,8 @@
dnnl_args_map_t args = {{DNNL_ARG_SRC, *input_mem}, {DNNL_ARG_DIFF_DST, *diff_dst_memory}};
if (req[0] != kAddTo) {
// req[0] is kWriteTo or kWriteInplace
- auto diff_src_memory = const_cast<NDArray&>(in_grad).CreateDNNLData(bwd.bwd_pd.diff_src_desc());
+ auto bwd_pd_diff_src_desc = bwd.bwd_pd.diff_src_desc();
+ auto diff_src_memory = const_cast<NDArray&>(in_grad).CreateDNNLData(&bwd_pd_diff_src_desc);
args.insert({DNNL_ARG_DIFF_SRC, *diff_src_memory});
stream->RegisterPrimArgs(bwd.GetBwd(), args);
stream->Submit();
@@ -301,8 +303,9 @@
auto input_mem = in_buffer.GetDNNLData();
// We need to make sure the two inputs to eltwise_backward has the same memory
// descriptor. Otherwise, the perf will suffer.
- if (input_mem->get_desc() != diff_dst_memory->get_desc())
- input_mem = in_buffer.GetDNNLDataReorder(diff_dst_memory->get_desc());
+ auto diff_dst_desc = diff_dst_memory->get_desc();
+ if (input_mem->get_desc() != diff_dst_desc)
+ input_mem = in_buffer.GetDNNLDataReorder(&diff_dst_desc);
DNNLActBackward& bwd = GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem);
DNNLStream* stream = DNNLStream::Get();
dnnl_output_t diff_src_memory = CreateDNNLMem(output, bwd.bwd_pd.diff_src_desc(), req[0]);
diff --git a/src/operator/nn/dnnl/dnnl_base.cc b/src/operator/nn/dnnl/dnnl_base.cc
index ec50da7..216a420 100644
--- a/src/operator/nn/dnnl/dnnl_base.cc
+++ b/src/operator/nn/dnnl/dnnl_base.cc
@@ -165,7 +165,7 @@
auto tmp = TmpMemMgr::Get()->Alloc(desc);
return dnnl_output_t(OutDataOp::AddBack, tmp);
} else if (kWriteInplace == req && in_arr != nullptr && CanWriteTo(out_arr, *in_arr, desc)) {
- dnnl::memory* mem = const_cast<NDArray&>(out_arr).CreateDNNLData(desc);
+ dnnl::memory* mem = const_cast<NDArray&>(out_arr).CreateDNNLData(&desc);
// mem is nullptr if out_arr is view and desc is DNNL format.
// need to Reorder2Default before calling CreateDNNLMem
CHECK(mem != nullptr);
@@ -174,7 +174,7 @@
auto tmp = TmpMemMgr::Get()->Alloc(desc);
return dnnl_output_t(OutDataOp::CopyBack, tmp);
} else if (kWriteTo == req) {
- dnnl::memory* mem = const_cast<NDArray&>(out_arr).CreateDNNLData(desc);
+ dnnl::memory* mem = const_cast<NDArray&>(out_arr).CreateDNNLData(&desc);
if (nullptr == mem) {
auto tmp = TmpMemMgr::Get()->Alloc(desc);
return dnnl_output_t(OutDataOp::CopyBack, tmp);
@@ -197,7 +197,7 @@
} else {
dnnl::memory* mem = nullptr;
if (IsDefaultFormat(desc)) {
- mem = const_cast<NDArray&>(out_arr).CreateDNNLData(desc);
+ mem = const_cast<NDArray&>(out_arr).CreateDNNLData(&desc);
}
if (mem == nullptr) {
auto tmp = TmpMemMgr::Get()->Alloc(desc);
@@ -214,7 +214,8 @@
} else if (res.first == AddBack) {
auto res_memory = res.second;
auto target_pd = arr.GetDNNLData()->get_desc();
- auto mem = arr.GetDNNLData(res.second->get_desc());
+ auto res_desc = res.second->get_desc();
+ auto mem = arr.GetDNNLData(&res_desc);
if (mem == nullptr) {
auto tmp_memory = TmpMemMgr::Get()->Alloc(target_pd);
DNNLMemoryCopy(*res_memory, tmp_memory);
@@ -272,19 +273,19 @@
LOG(FATAL) << "The weight array has an unsupported number of dimensions";
}
const auto md = dnnl::memory::desc{tz, type, format_tag};
- return arr.GetDNNLData(md);
+ return arr.GetDNNLData(&md);
}
const dnnl::memory* GetWeights(const NDArray& arr,
const dnnl::memory::desc& target_desc,
int num_groups) {
- const dnnl::memory* mem = arr.GetDNNLData(target_desc);
+ const dnnl::memory* mem = arr.GetDNNLData(&target_desc);
// If the weight array already uses the target layout, simply return it directly.
if (mem)
return mem;
mem = GetWeights(arr, num_groups);
if (mem == nullptr)
- mem = arr.GetDNNLDataReorder(target_desc);
+ mem = arr.GetDNNLDataReorder(&target_desc);
if (mem->get_desc() == target_desc)
return mem;
diff --git a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h b/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
index 97f21ae..8152b6c 100644
--- a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
+++ b/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
@@ -180,7 +180,8 @@
auto& fwd = GetBNForward<DType>(param, ctx, data_mem, flags);
// for output memory
- auto out_mem = const_cast<NDArray&>(out).CreateDNNLData(fwd.GetPd().dst_desc());
+ auto fwd_dst_desc = fwd.GetPd().dst_desc();
+ auto out_mem = const_cast<NDArray&>(out).CreateDNNLData(&fwd_dst_desc);
// mxnet will always use scale shift.
// But if fix_gamma is true, then all scale elements will be set to 1.0f
@@ -387,10 +388,13 @@
auto diff_mem = diff.GetDNNLData();
// DNNL batchnorm should run on special layouts. If one of them isn't, we
// should reorder them.
- if (data.IsDefaultData())
- data_mem = data.GetDNNLDataReorder(diff_mem->get_desc());
- else if (diff.IsDefaultData())
- diff_mem = diff.GetDNNLDataReorder(data_mem->get_desc());
+ if (data.IsDefaultData()) {
+ auto diff_desc = diff_mem->get_desc();
+ data_mem = data.GetDNNLDataReorder(&diff_desc);
+ } else if (diff.IsDefaultData()) {
+ auto data_desc = data_mem->get_desc();
+ diff_mem = diff.GetDNNLDataReorder(&data_desc);
+ }
auto& bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
auto gradi_mem =
CreateDNNLMem(const_cast<NDArray&>(gradIn), bwd.pd.diff_src_desc(), req[batchnorm::kData]);
diff --git a/src/operator/nn/dnnl/dnnl_convolution.cc b/src/operator/nn/dnnl/dnnl_convolution.cc
index 7a21290..f28f273 100644
--- a/src/operator/nn/dnnl/dnnl_convolution.cc
+++ b/src/operator/nn/dnnl/dnnl_convolution.cc
@@ -478,7 +478,8 @@
auto& weight = in_data[conv::kWeight];
bool no_bias = param.conv_param.no_bias && !param.dnnl_param.with_bn;
- auto data_mem = data.GetDNNLDataReorder(fwd->GetPd().src_desc());
+ auto fwd_src_desc = fwd->GetPd().src_desc();
+ auto data_mem = data.GetDNNLDataReorder(&fwd_src_desc);
const dnnl::memory* weight_mem;
if (ctx.is_train) {
// TODO(zhengda) kvstore doesn't handle DNNL correctly. Let's reorder it to the default format
@@ -493,10 +494,12 @@
if (weight.IsDefaultData()) {
// We also need to modify the layout on the original weight array. The data conversion happens
// after the weight array is used.
- weight.DNNLDataReorderAsync(fwd->GetPd().weights_desc());
+ auto fwd_weight_desc = fwd->GetPd().weights_desc();
+ weight.DNNLDataReorderAsync(&fwd_weight_desc);
weight_mem = GetWeights(weight, fwd->GetPd().weights_desc(), param.conv_param.num_group);
} else {
- weight_mem = weight.GetDNNLDataReorder(fwd->GetPd().weights_desc());
+ auto fwd_weight_desc = fwd->GetPd().weights_desc();
+ weight_mem = weight.GetDNNLDataReorder(&fwd_weight_desc);
}
}
dnnl_output_t out_mem;
@@ -599,8 +602,9 @@
const ConvolutionParam& param = full_param.conv_param;
CHECK_NE(req[conv::kWeight], kWriteInplace) << "cannot write weight inplace";
- DNNLConvBackward& convBwd = GetConvBwd(full_param, data, weight, bias, out_grad);
- auto out_grad_mem = out_grad.GetDNNLDataReorder(convBwd.GetDataPd().diff_dst_desc());
+ DNNLConvBackward& convBwd = GetConvBwd(full_param, data, weight, bias, out_grad);
+ auto convBwd_data_diff_desc = convBwd.GetDataPd().diff_dst_desc();
+ auto out_grad_mem = out_grad.GetDNNLDataReorder(&convBwd_data_diff_desc);
if (req[conv::kData]) {
auto weight_mem = GetWeights(weight, convBwd.GetDataPd().weights_desc(), param.num_group);
auto in_grad_mem =
@@ -615,10 +619,13 @@
auto req_weight = req.size() > conv::kWeight ? req.at(conv::kWeight) : kNullOp;
auto req_bias = req.size() > conv::kBias ? req.at(conv::kBias) : kNullOp;
if (req_weight || req_bias) {
- if (convBwd.GetDataPd().diff_dst_desc() != convBwd.GetWeightsPd().diff_dst_desc())
- out_grad_mem = out_grad.GetDNNLDataReorder(convBwd.GetWeightsPd().diff_dst_desc());
- auto data_mem = data.GetDNNLDataReorder(convBwd.GetWeightsPd().src_desc());
- auto in_grad_weight = CreateDNNLWeightGrad(
+ if (convBwd.GetDataPd().diff_dst_desc() != convBwd.GetWeightsPd().diff_dst_desc()) {
+ auto convBwd_weight_diff_desc = convBwd.GetWeightsPd().diff_dst_desc();
+ out_grad_mem = out_grad.GetDNNLDataReorder(&convBwd_weight_diff_desc);
+ }
+ auto convBwd_weight_src_desc = convBwd.GetWeightsPd().src_desc();
+ auto data_mem = data.GetDNNLDataReorder(&convBwd_weight_src_desc);
+ auto in_grad_weight = CreateDNNLWeightGrad(
in_grad[conv::kWeight], convBwd.GetWeightsPd().diff_weights_desc(), req[conv::kWeight]);
dnnl_args_map_t net_args = {{DNNL_ARG_DIFF_DST, *out_grad_mem},
diff --git a/src/operator/nn/dnnl/dnnl_copy.cc b/src/operator/nn/dnnl/dnnl_copy.cc
index 0fa9dc1..16cbabd 100644
--- a/src/operator/nn/dnnl/dnnl_copy.cc
+++ b/src/operator/nn/dnnl/dnnl_copy.cc
@@ -43,10 +43,11 @@
TmpMemMgr::Get()->Init(ctx.requested[0]);
// We should try and force the input memory has the same format
// as the input output. If not, we'll have to reorder memory.
- auto out_mem = out_data.GetDNNLData();
- in_mem = in_data.GetDNNLData(out_mem->get_desc());
+ auto out_mem = out_data.GetDNNLData();
+ auto out_mem_desc = out_mem->get_desc();
+ in_mem = in_data.GetDNNLData(&out_mem_desc);
if (in_mem == nullptr)
- in_mem = in_data.GetDNNLDataReorder(out_mem->get_desc());
+ in_mem = in_data.GetDNNLDataReorder(&out_mem_desc);
DNNLSum(*out_mem, *in_mem, *out_mem);
} else {
const_cast<NDArray&>(out_data).CopyFrom(*in_mem);
diff --git a/src/operator/nn/dnnl/dnnl_deconvolution-inl.h b/src/operator/nn/dnnl/dnnl_deconvolution-inl.h
index 1078423..a1ac551 100644
--- a/src/operator/nn/dnnl/dnnl_deconvolution-inl.h
+++ b/src/operator/nn/dnnl/dnnl_deconvolution-inl.h
@@ -82,7 +82,8 @@
temp.data_type(),
static_cast<dnnl::memory::format_tag>(GetDefaultFormat(temp.data.ndims)));
}
- const_cast<NDArray&>(arr).UpdateDNNLMemDesc(IOLogicalSwapDesc(desc, num_group));
+ auto iOLogicalSwapDesc = IOLogicalSwapDesc(desc, num_group);
+ const_cast<NDArray&>(arr).UpdateDNNLMemDesc(&iOLogicalSwapDesc);
}
// Version of GetWeightsDesc for deconvolution (with swap)
@@ -149,7 +150,8 @@
}
inline const dnnl::memory* DNNLDeconvFwd::DataMem(const NDArray& data) const {
- return data.GetDNNLDataReorder(fwd_pd->src_desc());
+ auto fwd_src_desc = fwd_pd->src_desc();
+ return data.GetDNNLDataReorder(&fwd_src_desc);
}
inline const dnnl::memory* DNNLDeconvFwd::WeightsMem(const uint32_t num_group,
@@ -275,7 +277,8 @@
}
inline const dnnl::memory* DNNLDeconvBwd::DataMem(const NDArray& data) const {
- return data.GetDNNLDataReorder(bwd_weights_pd->src_desc());
+ auto bwd_weight_src_desc = bwd_weights_pd->src_desc();
+ return data.GetDNNLDataReorder(&bwd_weight_src_desc);
}
inline const dnnl::memory* DNNLDeconvBwd::WeightsMem(const uint32_t num_group,
@@ -284,14 +287,16 @@
}
inline const dnnl::memory* DNNLDeconvBwd::OutGradMem(const NDArray& out_grad) const {
- return out_grad.GetDNNLDataReorder(bwd_data_pd->diff_dst_desc());
+ auto bwd_data_diff_desc = bwd_data_pd->diff_dst_desc();
+ return out_grad.GetDNNLDataReorder(&bwd_data_diff_desc);
}
inline const dnnl::memory* DNNLDeconvBwd::OutGradMem(const NDArray& out_grad,
const dnnl::memory* const out_grad_mem) const {
- return (out_grad_mem && out_grad_mem->get_desc() == bwd_weights_pd->diff_dst_desc()) ?
+ auto bwd_weight_diff_desc = bwd_weights_pd->diff_dst_desc();
+ return (out_grad_mem && out_grad_mem->get_desc() == bwd_weight_diff_desc) ?
out_grad_mem :
- out_grad.GetDNNLDataReorder(bwd_weights_pd->diff_dst_desc());
+ out_grad.GetDNNLDataReorder(&bwd_weight_diff_desc);
}
inline dnnl_output_t DNNLDeconvBwd::DataGradMem(const OpReqType req,
@@ -308,7 +313,7 @@
// swap, weights_md will have a default format
const auto& weights_md = bwd_weights_pd->diff_weights_desc();
if (req == OpReqType::kWriteTo && IsDefaultFormat(IOLogicalSwapDesc(weights_md, num_group))) {
- return {OutDataOp::Noop, const_cast<NDArray&>(weights_grad).CreateDNNLData(weights_md)};
+ return {OutDataOp::Noop, const_cast<NDArray&>(weights_grad).CreateDNNLData(&weights_md)};
}
return CreateDNNLWeightGrad(weights_grad, weights_md, req);
}
diff --git a/src/operator/nn/dnnl/dnnl_deconvolution.cc b/src/operator/nn/dnnl/dnnl_deconvolution.cc
index 9140d19..9487574 100644
--- a/src/operator/nn/dnnl/dnnl_deconvolution.cc
+++ b/src/operator/nn/dnnl/dnnl_deconvolution.cc
@@ -112,7 +112,8 @@
if (weights.IsDefaultData()) {
// We also need to modify the layout on the original weights array.
// The data conversion happens after the weights array is used.
- weights.DNNLDataReorderAsync(IOLogicalSwapDesc(fwd_pd->weights_desc(), num_group));
+ auto logical_swap_desc = IOLogicalSwapDesc(fwd_pd->weights_desc(), num_group);
+ weights.DNNLDataReorderAsync(&logical_swap_desc);
} else {
CHECK(weights.GetDNNLData()->get_desc() ==
IOLogicalSwapDesc(fwd_pd->weights_desc(), num_group));
diff --git a/src/operator/nn/dnnl/dnnl_fully_connected.cc b/src/operator/nn/dnnl/dnnl_fully_connected.cc
index 6f04b19..bf58135 100644
--- a/src/operator/nn/dnnl/dnnl_fully_connected.cc
+++ b/src/operator/nn/dnnl/dnnl_fully_connected.cc
@@ -186,10 +186,10 @@
const std::vector<OpReqType>& req,
const std::vector<NDArray>& out_data) {
TmpMemMgr::Get()->Init(ctx.requested[fullc::kTempSpace]);
- NDArray weight = in_data[fullc::kWeight];
- NDArray data = in_data[fullc::kData];
-
- auto data_mem = data.GetDNNLDataReorder(fwd->fwd_pd.src_desc());
+ NDArray weight = in_data[fullc::kWeight];
+ NDArray data = in_data[fullc::kData];
+ auto fwd_src_desc = fwd->fwd_pd.src_desc();
+ auto data_mem = data.GetDNNLDataReorder(&fwd_src_desc);
const dnnl::memory* weight_mem;
if (ctx.is_train) {
if (weight.IsDNNLData()) {
@@ -199,7 +199,8 @@
} else {
weight_mem = weight.GetDNNLData();
if (weight_mem->get_desc() != fwd->fwd_pd.weights_desc()) {
- weight.DNNLDataReorderAsync(fwd->fwd_pd.weights_desc());
+ auto fwd_weight_desc = fwd->fwd_pd.weights_desc();
+ weight.DNNLDataReorderAsync(&fwd_weight_desc);
weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1);
}
}
@@ -212,7 +213,8 @@
{DNNL_ARG_DST, *out_mem.second},
};
if (!full_param.default_param.no_bias) {
- auto bias_mem = in_data[fullc::kBias].GetDNNLDataReorder(fwd->fwd_pd.bias_desc());
+ auto fwd_bias_desc = fwd->fwd_pd.bias_desc();
+ auto bias_mem = in_data[fullc::kBias].GetDNNLDataReorder(&fwd_bias_desc);
args[DNNL_ARG_BIAS] = *bias_mem;
}
DNNLStream::Get()->RegisterPrimArgs(fwd->GetFwd(), args);
@@ -286,9 +288,11 @@
if (req[fullc::kWeight]) {
dnnl::inner_product_backward_weights::primitive_desc ipBwdWeights_pd = GetFCBwdWeights(
data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias], out_grad, fwd_pd);
- auto out_grad_mem = out_grad.GetDNNLDataReorder(ipBwdWeights_pd.diff_dst_desc());
- auto data_mem = data.GetDNNLDataReorder(ipBwdWeights_pd.src_desc());
- auto in_grad_weight = CreateDNNLWeightGrad(
+ auto ipBwdWeights_diff_dst_desc = ipBwdWeights_pd.diff_dst_desc();
+ auto ipBwdWeights_src_desc = ipBwdWeights_pd.src_desc();
+ auto out_grad_mem = out_grad.GetDNNLDataReorder(&ipBwdWeights_diff_dst_desc);
+ auto data_mem = data.GetDNNLDataReorder(&ipBwdWeights_src_desc);
+ auto in_grad_weight = CreateDNNLWeightGrad(
in_grad[fullc::kWeight], ipBwdWeights_pd.diff_weights_desc(), req[fullc::kWeight]);
dnnl_args_map_t args = {
{DNNL_ARG_DIFF_DST, *out_grad_mem},
@@ -312,8 +316,10 @@
if (req[fullc::kData]) {
dnnl::inner_product_backward_data::primitive_desc ipBwdData_pd =
GetFCBwdData(data, weight, out_grad, fwd_pd);
- auto out_grad_mem = out_grad.GetDNNLDataReorder(ipBwdData_pd.diff_dst_desc());
- auto weight_mem = weight.GetDNNLDataReorder(ipBwdData_pd.weights_desc());
+ auto ipBwdData_diff_dst_desc = ipBwdData_pd.diff_dst_desc();
+ auto ipBwdData_weight_desc = ipBwdData_pd.weights_desc();
+ auto out_grad_mem = out_grad.GetDNNLDataReorder(&ipBwdData_diff_dst_desc);
+ auto weight_mem = weight.GetDNNLDataReorder(&ipBwdData_weight_desc);
auto in_grad_mem =
CreateDNNLMem(in_grad[fullc::kData], ipBwdData_pd.diff_src_desc(), req[fullc::kData]);
dnnl_args_map_t args = {{DNNL_ARG_DIFF_DST, *out_grad_mem},
diff --git a/src/operator/nn/dnnl/dnnl_layer_norm.cc b/src/operator/nn/dnnl/dnnl_layer_norm.cc
index 4108a62..65eb0a3 100644
--- a/src/operator/nn/dnnl/dnnl_layer_norm.cc
+++ b/src/operator/nn/dnnl/dnnl_layer_norm.cc
@@ -134,10 +134,11 @@
const std::vector<NDArray>& outputs) const {
auto mean_var_md = GetMeanVarDesc(get_dnnl_type(outputs[layernorm::kMean].dtype()),
outputs[layernorm::kMean].shape());
- auto mean_mem = dnnl_output_t(
- OutDataOp::Noop, const_cast<NDArray&>(outputs[layernorm::kMean]).CreateDNNLData(mean_var_md));
+ auto mean_mem =
+ dnnl_output_t(OutDataOp::Noop,
+ const_cast<NDArray&>(outputs[layernorm::kMean]).CreateDNNLData(&mean_var_md));
auto variance_mem = dnnl_output_t(
- OutDataOp::Noop, const_cast<NDArray&>(outputs[layernorm::kStd]).CreateDNNLData(mean_var_md));
+ OutDataOp::Noop, const_cast<NDArray&>(outputs[layernorm::kStd]).CreateDNNLData(&mean_var_md));
auto output_mem = CreateDNNLMem(outputs[layernorm::kOut], fwd_pd->dst_desc(), req);
auto scale_shift_mem = GetScaleShiftMem(inputs[layernorm::kGamma], inputs[layernorm::kBeta]);
@@ -183,7 +184,8 @@
const std::vector<OpReqType>& req) const {
auto scale_shift_mem =
GetScaleShiftMem(inputs[layernorm::kBwdGamma], inputs[layernorm::kBwdBeta]);
- auto diff_weights_ndarray = NDArray(scale_shift_mem.get_desc());
+ auto scale_shift_mem_desc = scale_shift_mem.get_desc();
+ auto diff_weights_ndarray = NDArray(&scale_shift_mem_desc);
const auto bytes = inputs[layernorm::kBwdGamma].shape()[0] *
mshadow::mshadow_sizeof(inputs[layernorm::kBwdGamma].dtype());
const auto diff_weights_ndaray_data_ptr_plus_bytes = reinterpret_cast<void*>(
diff --git a/src/operator/nn/dnnl/dnnl_log_softmax.cc b/src/operator/nn/dnnl/dnnl_log_softmax.cc
index 9243440..be1abdb 100644
--- a/src/operator/nn/dnnl/dnnl_log_softmax.cc
+++ b/src/operator/nn/dnnl/dnnl_log_softmax.cc
@@ -127,9 +127,10 @@
int axis = CheckAxis(param.axis, in_data.shape().ndim());
auto fwd = GetLogSoftmaxFwd(param, axis, ctx.is_train, in_data, out_data);
- auto in_mem = in_data.GetDNNLData();
- auto out_mem = out_data.GetDNNLData(fwd.pd.dst_desc());
- DNNLStream* stream = DNNLStream::Get();
+ auto in_mem = in_data.GetDNNLData();
+ auto fwd_pd_dst_desc = fwd.pd.dst_desc();
+ auto out_mem = out_data.GetDNNLData(&fwd_pd_dst_desc);
+ DNNLStream* stream = DNNLStream::Get();
stream->RegisterPrimArgs(fwd.GetFwd(), {{DNNL_ARG_SRC, *in_mem}, {DNNL_ARG_DST, *out_mem}});
stream->Submit();
}
diff --git a/src/operator/nn/dnnl/dnnl_masked_softmax.cc b/src/operator/nn/dnnl/dnnl_masked_softmax.cc
index 789a6bf..a2a9c58 100644
--- a/src/operator/nn/dnnl/dnnl_masked_softmax.cc
+++ b/src/operator/nn/dnnl/dnnl_masked_softmax.cc
@@ -169,10 +169,11 @@
p.axis = param.axis;
p.temperature = param.temperature;
- auto softmax_tensors = DNNLSoftmaxFwd::Tensors(output, output);
- auto softmax_op = DNNLSoftmaxFwd::GetCached(p, softmax_tensors, is_train);
- auto softmax_out_mem = output.GetDNNLData(softmax_op.softmax_pd->dst_desc());
- const auto input_mem = input.GetDNNLData();
+ auto softmax_tensors = DNNLSoftmaxFwd::Tensors(output, output);
+ auto softmax_op = DNNLSoftmaxFwd::GetCached(p, softmax_tensors, is_train);
+ auto softmax_op_dst_desc = softmax_op.softmax_pd->dst_desc();
+ auto softmax_out_mem = output.GetDNNLData(&softmax_op_dst_desc);
+ const auto input_mem = input.GetDNNLData();
// 1. C) out = input * out
stream->RegisterPrimArgs(this->primitives->mask_input,
diff --git a/src/operator/nn/dnnl/dnnl_pooling.cc b/src/operator/nn/dnnl/dnnl_pooling.cc
index 85e72bb..c2dfe36 100644
--- a/src/operator/nn/dnnl/dnnl_pooling.cc
+++ b/src/operator/nn/dnnl/dnnl_pooling.cc
@@ -431,10 +431,11 @@
TmpMemMgr::Get()->Init(ctx.requested[0]);
- auto& bwd = GetPoolingBwd(param, *in_data, in_grad, out_grad, param.IsAdaptivePooling());
- auto diff_dst_mem = out_grad.GetDNNLDataReorder(bwd.pd.diff_dst_desc());
- auto diff_src_mem = CreateDNNLMem(in_grad, bwd.pd.diff_src_desc(), req[0]);
- dnnl_args_map_t args = {
+ auto& bwd = GetPoolingBwd(param, *in_data, in_grad, out_grad, param.IsAdaptivePooling());
+ auto bwd_diff_dst_desc = bwd.pd.diff_dst_desc();
+ auto diff_dst_mem = out_grad.GetDNNLDataReorder(&bwd_diff_dst_desc);
+ auto diff_src_mem = CreateDNNLMem(in_grad, bwd.pd.diff_src_desc(), req[0]);
+ dnnl_args_map_t args = {
{DNNL_ARG_DIFF_DST, *diff_dst_mem},
{DNNL_ARG_DIFF_SRC, *diff_src_mem.second},
};
diff --git a/src/operator/nn/dnnl/dnnl_reduce.cc b/src/operator/nn/dnnl/dnnl_reduce.cc
index f486c2f..a9be0af 100644
--- a/src/operator/nn/dnnl/dnnl_reduce.cc
+++ b/src/operator/nn/dnnl/dnnl_reduce.cc
@@ -225,7 +225,8 @@
auto out_mem = dnnl::memory(reduce_pd->dst_desc(), engine, tensors.out.data().dptr<float>());
stream->RegisterPrimArgs(*reduce_fwd, {{DNNL_ARG_SRC, *input_mem}, {DNNL_ARG_DST, out_mem}});
} else {
- auto out_mem = tensors.out.GetDNNLData(reduce_pd->dst_desc());
+ auto desc = reduce_pd->dst_desc();
+ auto out_mem = tensors.out.GetDNNLData(&desc);
stream->RegisterPrimArgs(*reduce_fwd, {{DNNL_ARG_SRC, *input_mem}, {DNNL_ARG_DST, *out_mem}});
}
stream->Submit();
diff --git a/src/operator/nn/dnnl/dnnl_softmax-inl.h b/src/operator/nn/dnnl/dnnl_softmax-inl.h
index c681f33..42558c6 100644
--- a/src/operator/nn/dnnl/dnnl_softmax-inl.h
+++ b/src/operator/nn/dnnl/dnnl_softmax-inl.h
@@ -81,7 +81,6 @@
std::shared_ptr<linear_t> temperature_fwd;
};
-
class DNNLSoftmaxBwd {
public:
struct Tensors {
@@ -107,7 +106,6 @@
std::shared_ptr<linear_t> temperature_fwd;
};
-
} // namespace op
} // namespace mxnet
#endif
diff --git a/src/operator/nn/dnnl/dnnl_softmax.cc b/src/operator/nn/dnnl/dnnl_softmax.cc
index 93fe557..73321a1 100644
--- a/src/operator/nn/dnnl/dnnl_softmax.cc
+++ b/src/operator/nn/dnnl/dnnl_softmax.cc
@@ -139,8 +139,9 @@
void DNNLSoftmaxFwd::Execute(const Tensors& tensors) const {
DNNLStream* stream = DNNLStream::Get();
- auto original_input_mem = tensors.data.GetDNNLData();
- const auto out_mem = tensors.out.GetDNNLData(softmax_pd->dst_desc());
+ auto original_input_mem = tensors.data.GetDNNLData();
+ auto softmax_pd_dst_desc = softmax_pd->dst_desc();
+ const auto out_mem = tensors.out.GetDNNLData(&softmax_pd_dst_desc);
dnnl::memory* softmax_input_mem;
if (temperature_pd) {
diff --git a/src/operator/nn/dnnl/dnnl_where.cc b/src/operator/nn/dnnl/dnnl_where.cc
index c2335b9..7d5ca4a 100644
--- a/src/operator/nn/dnnl/dnnl_where.cc
+++ b/src/operator/nn/dnnl/dnnl_where.cc
@@ -171,9 +171,13 @@
const auto& cpu_engine = CpuEngine::Get()->get_engine();
const auto& cpu_stream = ctx.get_stream<cpu>();
- const auto& cnd_tensor = tensors.condition.GetDNNLDataReorder(binary_eq_zero_pd.src0_desc());
- const auto& lhs_tensor = tensors.left.GetDNNLDataReorder(binary_mul_l_pd.src0_desc());
- const auto& rhs_tensor = tensors.right.GetDNNLDataReorder(binary_mul_r_pd.src0_desc());
+ auto binary_eq_zero_pd_desc = binary_eq_zero_pd.src0_desc();
+ auto binary_mul_l_pd_desc = binary_mul_l_pd.src0_desc();
+ auto binary_mul_r_pd_desc = binary_mul_r_pd.src0_desc();
+
+ const auto& cnd_tensor = tensors.condition.GetDNNLDataReorder(&binary_eq_zero_pd_desc);
+ const auto& lhs_tensor = tensors.left.GetDNNLDataReorder(&binary_mul_l_pd_desc);
+ const auto& rhs_tensor = tensors.right.GetDNNLDataReorder(&binary_mul_r_pd_desc);
mxnet::dnnl_output_t out_mem = CreateDNNLMem(tensors.output, binary_sum_pd.dst_desc(), req[0]);
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_batch_norm.cc b/src/operator/quantization/dnnl/dnnl_quantized_batch_norm.cc
index 3f13775..ba3b2f8 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_batch_norm.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_batch_norm.cc
@@ -114,7 +114,8 @@
}
const NDArray& out = outputs[batchnorm::kOut];
- auto out_mem = const_cast<NDArray&>(out).CreateDNNLData(fwd.GetPd().dst_desc());
+ auto fwd_dst_desc = fwd.GetPd().dst_desc();
+ auto out_mem = const_cast<NDArray&>(out).CreateDNNLData(&fwd_dst_desc);
dnnl_args_map_t net_args;
net_args[DNNL_ARG_SRC] = *data_mem;
net_args[DNNL_ARG_SCALE_SHIFT] = weight_mem;
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_conv.cc b/src/operator/quantization/dnnl/dnnl_quantized_conv.cc
index 158d0ea..85b915e 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_conv.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_conv.cc
@@ -45,13 +45,14 @@
DNNLConvFullParam full_param;
full_param.conv_param = param;
full_param.dnnl_param.Init(std::unordered_map<std::string, std::string>());
- auto& fwd = GetConvFwd(full_param,
+ auto& fwd = GetConvFwd(full_param,
ctx.is_train,
in_data[conv::kData],
in_data[conv::kWeight],
param.no_bias ? nullptr : &in_data[conv::kBias],
out_data[conv::kOut]);
- auto data_mem = in_data[conv::kData].GetDNNLDataReorder(fwd.GetPd().src_desc());
+ auto fwd_src_desc = fwd.GetPd().src_desc();
+ auto data_mem = in_data[conv::kData].GetDNNLDataReorder(&fwd_src_desc);
const dnnl::memory* weight_mem;
// For inference, we want to reorder the weight array so we don't need to
// reorder data every time.
@@ -59,15 +60,17 @@
// We also need to modify the layout on the original weight array.
// Don't switch below sequence because naive engine will executes
// pushAsync synchronously.
- weight.DNNLDataReorderAsync(fwd.GetPd().weights_desc());
- weight_mem = GetWeights(weight, fwd.GetPd().weights_desc(), param.num_group);
+ auto fwd_weight_desc = fwd.GetPd().weights_desc();
+ weight.DNNLDataReorderAsync(&fwd_weight_desc);
+ weight_mem = GetWeights(weight, fwd_weight_desc, param.num_group);
} else {
weight_mem = weight.GetDNNLData();
}
auto out_mem = CreateDNNLMem(out_data[conv::kOut], fwd.GetPd().dst_desc(), req[conv::kOut]);
dnnl_args_map_t net_args;
if (!param.no_bias) {
- const dnnl::memory* bias_mem = in_data[conv::kBias].GetDNNLDataReorder(fwd.GetPd().bias_desc());
+ auto fwd_bias_desc = fwd.GetPd().bias_desc();
+ const dnnl::memory* bias_mem = in_data[conv::kBias].GetDNNLDataReorder(&fwd_bias_desc);
net_args.insert({DNNL_ARG_BIAS, *bias_mem});
}
net_args.insert({DNNL_ARG_SRC, *data_mem});
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_fully_connected.cc b/src/operator/quantization/dnnl/dnnl_quantized_fully_connected.cc
index 7746129..ca75f54 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_fully_connected.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_fully_connected.cc
@@ -90,15 +90,17 @@
auto& fwd =
GetFCFwd(param, is_train, data, weight, param.no_bias ? nullptr : &quantized_bias, out_md);
- auto data_mem = in_data[fullc::kData].GetDNNLDataReorder(fwd.fwd_pd.src_desc());
+ auto fwd_src_desc = fwd.fwd_pd.src_desc();
+ auto data_mem = in_data[fullc::kData].GetDNNLDataReorder(&fwd_src_desc);
const dnnl::memory* weight_mem = nullptr;
if (weight.IsDefaultData()) {
// We also need to modify the layout on the original weight array.
// Don't switch below sequence because naive engine will executes
// pushAsync synchronously.
- weight.DNNLDataReorderAsync(fwd.fwd_pd.weights_desc());
- weight_mem = GetWeights(weight, fwd.fwd_pd.weights_desc(), 1);
+ auto fwd_weight_desc = fwd.fwd_pd.weights_desc();
+ weight.DNNLDataReorderAsync(&fwd_weight_desc);
+ weight_mem = GetWeights(weight, fwd_weight_desc, 1);
} else {
weight_mem = weight.GetDNNLData();
CHECK(weight_mem->get_desc() == fwd.fwd_pd.weights_desc());
@@ -113,7 +115,8 @@
const dnnl::memory* bias_mem = nullptr;
if (!param.no_bias) {
- bias_mem = quantized_bias.GetDNNLDataReorder(fwd.fwd_pd.bias_desc());
+ auto fwd_bias_desc = fwd.fwd_pd.bias_desc();
+ bias_mem = quantized_bias.GetDNNLDataReorder(&fwd_bias_desc);
args[DNNL_ARG_BIAS] = *bias_mem;
}
diff --git a/src/operator/subgraph/dnnl/dnnl_common.h b/src/operator/subgraph/dnnl/dnnl_common.h
index 68f10c0..3db25aa 100644
--- a/src/operator/subgraph/dnnl/dnnl_common.h
+++ b/src/operator/subgraph/dnnl/dnnl_common.h
@@ -103,7 +103,7 @@
const std::vector<float>& weight_scales,
const bool submit = true) {
DNNLStream* stream = DNNLStream::Get();
- const auto new_weight = NDArray(weight_md);
+ const auto new_weight = NDArray(&weight_md);
const auto conv_weights_memory = new_weight.GetDNNLData();
dnnl::primitive_attr weight_attr;
if (weight_scales.size()) {
@@ -124,7 +124,7 @@
for (size_t c = 0; c < weight_scales.size(); ++c) {
bias_scales[c] = weight_scales[c] * data_scale;
}
- new_bias = NDArray(*bias_md);
+ new_bias = NDArray(bias_md);
const auto conv_bias_memory = new_bias.GetDNNLData();
const int bias_mask = (bias_scales.size()) == 1 ? 0 : 1;
dnnl::primitive_attr bias_attr;
diff --git a/src/operator/subgraph/dnnl/dnnl_conv.cc b/src/operator/subgraph/dnnl/dnnl_conv.cc
index 2712ee9..2627460 100644
--- a/src/operator/subgraph/dnnl/dnnl_conv.cc
+++ b/src/operator/subgraph/dnnl/dnnl_conv.cc
@@ -358,7 +358,7 @@
const auto& out_mem_desc = output_mem->get_desc();
const auto& dst_mem_desc = fwd_->GetPd().dst_desc();
if (out_mem_desc != dst_mem_desc) {
- auto tmp_out_mem = output.GetDNNLDataReorder(fwd_->GetPd().dst_desc());
+ auto tmp_out_mem = output.GetDNNLDataReorder(&dst_mem_desc);
auto data_md = dst_mem_desc;
data_md.data.data_type = static_cast<dnnl_data_type_t>(out_mem_desc.data.data_type);
dnnl_mem_ptr new_out_mem(
@@ -370,10 +370,12 @@
}
if (dnnl_param.quantized) {
- auto data_mem = data.GetDNNLDataReorder(fwd_->GetPd().src_desc());
- dnnl::memory* mem = output.CreateDNNLData(fwd_->GetPd().dst_desc());
- args_[DNNL_ARG_SRC] = *data_mem;
- args_[DNNL_ARG_DST] = *mem;
+ auto fwd_src_desc = fwd_->GetPd().src_desc();
+ auto data_mem = data.GetDNNLDataReorder(&fwd_src_desc);
+ auto fwd_pd_dst_desc = fwd_->GetPd().dst_desc();
+ dnnl::memory* mem = output.CreateDNNLData(&fwd_pd_dst_desc);
+ args_[DNNL_ARG_SRC] = *data_mem;
+ args_[DNNL_ARG_DST] = *mem;
DNNLStream::Get()->RegisterPrimArgs(fwd_->GetFwd(), args_);
DNNLStream::Get()->Submit();
} else {
@@ -391,8 +393,9 @@
*outputs[kMax].data().dptr<float>() = cached_output_max_;
}
if (dnnl_param.with_sum) {
- auto out = const_cast<NDArray&>(outputs[kOut]);
- out.UpdateDNNLMemDesc(fwd_->GetPd().dst_desc());
+ auto out = const_cast<NDArray&>(outputs[kOut]);
+ auto fwd_dst_desc = fwd_->GetPd().dst_desc();
+ out.UpdateDNNLMemDesc(&fwd_dst_desc);
}
}
diff --git a/src/operator/subgraph/dnnl/dnnl_fc.cc b/src/operator/subgraph/dnnl/dnnl_fc.cc
index f9d0d57..a5c199f 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc.cc
+++ b/src/operator/subgraph/dnnl/dnnl_fc.cc
@@ -417,7 +417,7 @@
const auto def_weight_mem = static_cast<const dnnl::memory*>(weight.GetDNNLData());
if (def_weight_mem->get_desc() != fwd_->fwd_pd.weights_desc()) {
auto weight_desc = fwd_->fwd_pd.weights_desc();
- cached_weight_ = NDArray(weight_desc);
+ cached_weight_ = NDArray(&weight_desc);
auto cached_weight_mem = static_cast<const dnnl::memory*>(cached_weight_.GetDNNLData());
std::unordered_map<int, dnnl::memory> args(
{{DNNL_ARG_FROM, *def_weight_mem}, {DNNL_ARG_TO, *cached_weight_mem}});
@@ -442,7 +442,7 @@
const auto& out_mem_desc = output_mem->get_desc();
auto dst_mem_desc = fwd_->fwd_pd.dst_desc();
if (out_mem_desc != dst_mem_desc) {
- auto tmp_out_mem = output.GetDNNLDataReorder(dst_mem_desc);
+ auto tmp_out_mem = output.GetDNNLDataReorder(&dst_mem_desc);
dst_mem_desc.data.data_type = out_mem_desc.data.data_type;
dnnl_mem_ptr new_out_mem(new dnnl::memory(
dst_mem_desc, CpuEngine::Get()->get_engine(), output_mem->get_data_handle()));
diff --git a/tests/cpp/include/test_dnnl.h b/tests/cpp/include/test_dnnl.h
index 7172b0b..166074e 100644
--- a/tests/cpp/include/test_dnnl.h
+++ b/tests/cpp/include/test_dnnl.h
@@ -86,7 +86,7 @@
bool is_rand = false,
int max = 50) {
InitDefaultArray(arr, is_rand, max);
- arr->DNNLDataReorderAsync(desc);
+ arr->DNNLDataReorderAsync(&desc);
arr->WaitToRead();
}
diff --git a/tests/cpp/operator/dnnl_test.cc b/tests/cpp/operator/dnnl_test.cc
index 510d368..2874df1 100644
--- a/tests/cpp/operator/dnnl_test.cc
+++ b/tests/cpp/operator/dnnl_test.cc
@@ -139,7 +139,7 @@
InitDefaultArray(&arr);
for (auto md : mds) {
if (s.Size() == md.get_size() / sizeof(mshadow::default_real_t)) {
- const dnnl::memory* mem = arr.GetDNNLDataReorder(md);
+ const dnnl::memory* mem = arr.GetDNNLDataReorder(&md);
printf("reorder from (");
for (size_t i = 0; i < s.ndim(); i++)
printf("%ld, ", s[i]);
@@ -171,7 +171,7 @@
InitDNNLArray(&arr, md);
for (auto to_md : mds) {
if (to_md.get_size() / sizeof(mshadow::default_real_t) == s.Size()) {
- const dnnl::memory* mem = arr.GetDNNLDataReorder(to_md);
+ const dnnl::memory* mem = arr.GetDNNLDataReorder(&to_md);
printf("reorder from (");
for (size_t i = 0; i < s.ndim(); i++)
printf("%ld, ", s[i]);