/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#if MXNET_USE_ONEDNN == 1
#include <atomic>
#include "../../../common/exec_utils.h"
#include "operator/operator_common.h"
#include "dnnl_base-inl.h"
#include "dnnl_ops-inl.h"
namespace mxnet {
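// Returns the thread-local DNNL stream, which collects memory objects and primitives
// registered by operators on this thread until Submit() is called.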
DNNLStream* DNNLStream::Get() {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local DNNLStream stream;
#else
static MX_THREAD_LOCAL DNNLStream stream;
#endif
return &stream;
}
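// Aligns `mem` to `alignment` bytes within a buffer that has `*space` bytes left.
// Returns the aligned address and subtracts the padding from `*space`, or returns
// nullptr if the remaining space cannot hold `size` bytes after alignment.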
void* AlignMem(void* mem, size_t size, size_t alignment, size_t* space) {
if (size > *space)
return nullptr;
intptr_t addr = reinterpret_cast<intptr_t>(mem);
// If the address has been aligned, don't do anything.
intptr_t last_chunk = addr % alignment;
if (last_chunk == 0)
return mem;
intptr_t padding = alignment - last_chunk;
// If the buffer doesn't have enough space, we should return null here.
if (padding + size > *space)
return nullptr;
addr += padding;
*space -= padding;
CHECK_EQ(addr % alignment, 0);
return reinterpret_cast<void*>(addr);
}
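// Allocates a dnnl::memory with the given descriptor from the operator's temporary
// space if it is large enough; otherwise oneDNN allocates the memory itself while the
// required size keeps being accumulated in est_size for the next iteration.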
dnnl::memory* TmpMemMgr::Alloc(const dnnl::memory::desc& md) {
// We need to include the size of the memory used for alignment.
this->est_size += md.get_size() + alignment;
void* mem = AlignMem(this->curr_mem, md.get_size(), alignment, &this->curr_size);
if (mem) {
// The memory is allocated from the temporary memory space in the
// operator. It'll only become invalid after we exit from the operator.
dnnl_mem_ptr ret(new dnnl::memory(md, CpuEngine::Get()->get_engine(), mem));
DNNLStream::Get()->RegisterMem(ret);
// Sanity check: the DNNL memory must wrap the aligned pointer we provided.
CHECK_EQ(ret->get_data_handle(), mem);
this->curr_size -= md.get_size();
this->curr_mem = static_cast<char*>(mem) + md.get_size();
return ret.get();
} else {
// If curr_mem has been initialized and we still reach here, the currently
// allocated memory isn't large enough. That is fine across multiple invocations
// of an operator: TmpMemMgr estimates the required space during the first
// iteration and then requests sufficient space from the MXNet resource manager,
// while oneDNN allocates the memory itself in the meantime. Thus, we just keep
// accumulating the estimate of the maximum required size; the space will be
// allocated on the next call.
if (this->curr_mem && dmlc::GetEnv("MXNET_ONEDNN_DEBUG", false)) {
LOG(WARNING) << "oneDNN debug message: The rest of the temporary space is not "
<< "adequate for allocating " << md.get_size() << " bytes. Thus, oneDNN "
<< "allocate the space by itself.";
}
dnnl_mem_ptr ret(new dnnl::memory(md, CpuEngine::Get()->get_engine()));
DNNLStream::Get()->RegisterMem(ret);
return ret.get();
}
}
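// Copies the data of `mem` into `this_mem`, registering the reorder primitives (and,
// when the shapes differ, the intermediate views) needed to bridge the two descriptors.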
void DNNLMemoryCopy(const dnnl::memory& mem, const dnnl::memory* this_mem) {
DNNLStream* stream = DNNLStream::Get();
dnnl::memory::desc from_desc = mem.get_desc();
dnnl::memory::desc this_desc = this_mem->get_desc();
dnnl_format_tag_t from_def_format = GetDefaultFormat(from_desc);
dnnl_format_tag_t this_def_format = GetDefaultFormat(this_desc);
if (!same_shape(this_desc, from_desc) && IsDefaultFormat(from_desc)) {
// In this case, we can simply create a new DNNL memory for the required
// shape.
dnnl::memory::dims dims(this_desc.data.dims, this_desc.data.dims + this_desc.data.ndims);
auto this_dtype = static_cast<dnnl::memory::data_type>(this_desc.data.data_type);
dnnl::memory::desc data_md(
dims, this_dtype, static_cast<dnnl::memory::format_tag>(this_def_format));
dnnl_mem_ptr tmp_mem(new dnnl::memory(data_md, mem.get_engine(), mem.get_data_handle()));
stream->RegisterMem(tmp_mem);
std::unordered_map<int, dnnl::memory> args(
{{DNNL_ARG_FROM, *tmp_mem}, {DNNL_ARG_TO, *this_mem}});
stream->RegisterPrimArgs(dnnl::reorder(*tmp_mem, *this_mem), args);
} else if (!same_shape(this_desc, from_desc)) {
// In this case, the source memory stores data in a customized layout. We
// need to reorganize the data in memory before we can reshape.
dnnl::memory::desc def_desc = GetDesc(from_desc, from_def_format);
dnnl::memory* def_mem = TmpMemMgr::Get()->Alloc(def_desc);
std::unordered_map<int, dnnl::memory> args({{DNNL_ARG_FROM, mem}, {DNNL_ARG_TO, *def_mem}});
stream->RegisterPrimArgs(dnnl::reorder(mem, *def_mem), args);
// Now we can reshape it
dnnl_mem_ptr tmp_mem(new dnnl::memory(this_desc, mem.get_engine(), def_mem->get_data_handle()));
stream->RegisterMem(tmp_mem);
args = {{DNNL_ARG_FROM, *tmp_mem}, {DNNL_ARG_TO, *this_mem}};
stream->RegisterPrimArgs(dnnl::reorder(*tmp_mem, *this_mem), args);
} else if (this_desc == from_desc) {
std::unordered_map<int, dnnl::memory> args({{DNNL_ARG_FROM, mem}, {DNNL_ARG_TO, *this_mem}});
// If the layout is the same, we can just copy data.
stream->RegisterPrimArgs(dnnl::reorder(mem, *this_mem), args);
} else {
// If neither memory uses the default layout, there isn't much we can do
// other than reorder the data layout directly.
if (!IsDefaultFormat(this_desc) && !IsDefaultFormat(from_desc)) {
std::unordered_map<int, dnnl::memory> args({{DNNL_ARG_FROM, mem}, {DNNL_ARG_TO, *this_mem}});
stream->RegisterPrimArgs(dnnl::reorder(mem, *this_mem), args);
} else if (IsDefaultFormat(this_desc)) {
// If the dest mem uses the default memory layout, we can simply view it with
// the default format of the source memory to improve the reorder performance.
dnnl::memory::desc desc = GetDesc(from_desc, from_def_format);
dnnl_mem_ptr tmp_mem(new dnnl::memory(desc, mem.get_engine(), this_mem->get_data_handle()));
stream->RegisterMem(tmp_mem);
std::unordered_map<int, dnnl::memory> args({{DNNL_ARG_FROM, mem}, {DNNL_ARG_TO, *tmp_mem}});
stream->RegisterPrimArgs(dnnl::reorder(mem, *tmp_mem), args);
} else {
// If the src mem uses the default memory layout, we can view it with the
// default format of the destination memory to improve performance.
dnnl::memory::desc desc = GetDesc(this_desc, this_def_format);
dnnl_mem_ptr tmp_mem(new dnnl::memory(desc, this_mem->get_engine(), mem.get_data_handle()));
stream->RegisterMem(tmp_mem);
std::unordered_map<int, dnnl::memory> args(
{{DNNL_ARG_FROM, *tmp_mem}, {DNNL_ARG_TO, *this_mem}});
stream->RegisterPrimArgs(dnnl::reorder(*tmp_mem, *this_mem), args);
}
}
}
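// Returns true when the output can be written in place: the input and output share
// the same data handle and both already use the requested descriptor.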
bool CanWriteTo(const NDArray& out_arr, const NDArray& in_arr, const dnnl::memory::desc& desc) {
auto in_mem = in_arr.GetDNNLData();
bool add_same = in_mem->get_data_handle() == out_arr.GetDNNLData()->get_data_handle();
bool pdesc_same = out_arr.GetDNNLData()->get_desc() == desc && in_mem->get_desc() == desc;
return add_same && pdesc_same;
}
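// Creates the output memory for an operator according to `req` and returns it together
// with the action (Noop, CopyBack or AddBack) that CommitOutput has to perform later.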
dnnl_output_t CreateDNNLMem(const NDArray& out_arr,
const dnnl::memory::desc& desc,
OpReqType req,
const NDArray* in_arr) {
if (kAddTo == req) {
auto tmp = TmpMemMgr::Get()->Alloc(desc);
return dnnl_output_t(OutDataOp::AddBack, tmp);
} else if (kWriteInplace == req && in_arr != nullptr && CanWriteTo(out_arr, *in_arr, desc)) {
dnnl::memory* mem = const_cast<NDArray&>(out_arr).CreateDNNLData(desc);
// mem is nullptr if out_arr is a view and desc is a DNNL (non-default) format;
// callers need to Reorder2Default before calling CreateDNNLMem.
CHECK(mem != nullptr);
return dnnl_output_t(OutDataOp::Noop, mem);
} else if (kWriteInplace == req) {
auto tmp = TmpMemMgr::Get()->Alloc(desc);
return dnnl_output_t(OutDataOp::CopyBack, tmp);
} else if (kWriteTo == req) {
dnnl::memory* mem = const_cast<NDArray&>(out_arr).CreateDNNLData(desc);
if (nullptr == mem) {
auto tmp = TmpMemMgr::Get()->Alloc(desc);
return dnnl_output_t(OutDataOp::CopyBack, tmp);
}
return dnnl_output_t(OutDataOp::Noop, mem);
}
auto tmp = TmpMemMgr::Get()->Alloc(desc);
return dnnl_output_t(OutDataOp::Noop, tmp);
}
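// Variant of CreateDNNLMem for weight gradients: the output array is written directly
// only when the descriptor uses the default layout.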
dnnl_output_t CreateDNNLWeightGrad(const NDArray& out_arr,
const dnnl::memory::desc& desc,
OpReqType req) {
if (kAddTo == req) {
auto tmp = TmpMemMgr::Get()->Alloc(desc);
return dnnl_output_t(OutDataOp::AddBack, tmp);
} else if (kWriteInplace == req) {
auto tmp = TmpMemMgr::Get()->Alloc(desc);
return dnnl_output_t(OutDataOp::CopyBack, tmp);
} else {
dnnl::memory* mem = nullptr;
if (IsDefaultFormat(desc)) {
mem = const_cast<NDArray&>(out_arr).CreateDNNLData(desc);
}
if (mem == nullptr) {
auto tmp = TmpMemMgr::Get()->Alloc(desc);
return dnnl_output_t(OutDataOp::CopyBack, tmp);
} else {
return dnnl_output_t(OutDataOp::Noop, mem);
}
}
}
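// Finalizes an operator output: copies the temporary result back into `arr` for
// CopyBack, or accumulates it into the existing data for AddBack.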
void CommitOutput(const NDArray& arr, const dnnl_output_t& res) {
if (res.first == CopyBack) {
const_cast<NDArray&>(arr).CopyFrom(*res.second);
} else if (res.first == AddBack) {
auto res_memory = res.second;
auto target_pd = arr.GetDNNLData()->get_desc();
auto mem = arr.GetDNNLData(res.second->get_desc());
if (mem == nullptr) {
auto tmp_memory = TmpMemMgr::Get()->Alloc(target_pd);
DNNLMemoryCopy(*res_memory, tmp_memory);
res_memory = tmp_memory;
mem = arr.GetDNNLData();
}
op::DNNLSum(*mem, *res_memory, *mem);
}
}
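// Returns the weight array as a dnnl::memory in the plain oi/oiw/oihw/oidhw layout
// (or the grouped goi* variant when num_groups > 1) matching its number of dimensions.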
const dnnl::memory* GetWeights(const NDArray& arr, int num_groups) {
const auto type = get_dnnl_type(arr.dtype());
auto tz = dnnl::memory::dims{0};
auto format_tag = dnnl::memory::format_tag::undef;
auto engine = CpuEngine::Get()->get_engine();
const int ndim = arr.shape().ndim();
int O = 0, I = 1, H = 2, W = 3;
int D = -1;
if (ndim == 5) {
D = 2;
H = 3;
W = 4;
}
if (ndim == 2) {
tz = dnnl::memory::dims{arr.shape()[O], arr.shape()[I]};
format_tag = dnnl::memory::format_tag::oi;
} else if (ndim == 3) {
tz = num_groups > 1 ?
dnnl::memory::dims{
num_groups, arr.shape()[O] / num_groups, arr.shape()[I], arr.shape()[H]} :
dnnl::memory::dims{arr.shape()[O], arr.shape()[I], arr.shape()[H]};
format_tag = num_groups > 1 ? dnnl::memory::format_tag::goiw : dnnl::memory::format_tag::oiw;
} else if (ndim == 4) {
tz = num_groups > 1 ?
dnnl::memory::dims{num_groups,
arr.shape()[O] / num_groups,
arr.shape()[I],
arr.shape()[H],
arr.shape()[W]} :
dnnl::memory::dims{arr.shape()[O], arr.shape()[I], arr.shape()[H], arr.shape()[W]};
format_tag = num_groups > 1 ? dnnl::memory::format_tag::goihw : dnnl::memory::format_tag::oihw;
} else if (ndim == 5) {
tz = num_groups > 1 ?
dnnl::memory::dims{num_groups,
arr.shape()[O] / num_groups,
arr.shape()[I],
arr.shape()[D],
arr.shape()[H],
arr.shape()[W]} :
dnnl::memory::dims{
arr.shape()[O], arr.shape()[I], arr.shape()[D], arr.shape()[H], arr.shape()[W]};
format_tag =
num_groups > 1 ? dnnl::memory::format_tag::goidhw : dnnl::memory::format_tag::oidhw;
} else {
LOG(FATAL) << "The weight array has an unsupported number of dimensions";
}
const auto md = dnnl::memory::desc{tz, type, format_tag};
return arr.GetDNNLData(md);
}
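// Returns the weight array in `target_desc`, going through the plain weight layout and
// an extra reorder when the array is not already stored in the requested format.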
const dnnl::memory* GetWeights(const NDArray& arr,
const dnnl::memory::desc& target_desc,
int num_groups) {
const dnnl::memory* mem = arr.GetDNNLData(target_desc);
// If the weight array already uses the target layout, simply return it directly.
if (mem)
return mem;
mem = GetWeights(arr, num_groups);
if (mem == nullptr)
mem = arr.GetDNNLDataReorder(target_desc);
if (mem->get_desc() == target_desc)
return mem;
auto ret = TmpMemMgr::Get()->Alloc(target_desc);
std::unordered_map<int, dnnl::memory> args({{DNNL_ARG_FROM, *mem}, {DNNL_ARG_TO, *ret}});
DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(*mem, *ret), args);
return ret;
}
// A memory descriptor uses the default layout when it is plain blocked (no inner
// blocks) and its strides never increase with the dimension index. Anything else
// (e.g. winograd, rnn packed, or blocked formats with inner blocking or increasing
// strides) is treated as a DNNL-specific layout.
bool IsDNNL(const dnnl::memory::desc& desc) {
bool rslt = true;
if (desc.data.format_kind == dnnl_blocked) {
if (desc.data.format_desc.blocking.inner_nblks == 0) {
int i = 0;
for (i = 0; i < desc.data.ndims - 1; i++) {
if (desc.data.format_desc.blocking.strides[i] <
desc.data.format_desc.blocking.strides[i + 1]) {
break;
}
}
if (i == desc.data.ndims - 1) {
rslt = false;
}
}
}
return rslt;
}
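// Maps a dimension count to the corresponding plain (row-major) oneDNN format tag.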
dnnl_format_tag_t GetDefaultFormat(int num_dims) {
switch (num_dims) {
case 1:
return dnnl_a;
case 2:
return dnnl_ab;
case 3:
return dnnl_abc;
case 4:
return dnnl_abcd;
case 5:
return dnnl_abcde;
case 6:
return dnnl_abcdef;
case 7:
return dnnl_abcdefg;
case 8:
return dnnl_abcdefgh;
case 9:
return dnnl_abcdefghi;
case 10:
return dnnl_abcdefghij;
case 11:
return dnnl_abcdefghijk;
case 12:
return dnnl_abcdefghijkl;
default:
LOG(FATAL) << "Not implemented dimension (" << num_dims << ") for oneDNN";
return dnnl_format_tag_undef;
}
}
dnnl_format_tag_t GetPermutedFormat(int num_dims) {
switch (num_dims) {
case 1:
return dnnl_a;
case 2:
return dnnl_ba;
case 3:
return dnnl_acb;
case 4:
return dnnl_abdc;
case 5:
return dnnl_abced;
case 6:
return dnnl_abcdfe;
case 7:
return dnnl_abcdegf;
case 8:
return dnnl_abcdefhg;
case 9:
return dnnl_abcdefgih;
case 10:
return dnnl_abcdefghji;
case 11:
return dnnl_abcdefghikj;
case 12:
return dnnl_abcdefghijlk;
default:
LOG(FATAL) << "Not implemented dimension (" << num_dims << ") for oneDNN";
return dnnl_format_tag_undef;
}
}
dnnl_format_tag_t GetDefaultFormat(const dnnl::memory::desc& desc) {
return GetDefaultFormat(desc.data.ndims);
}
bool IsDefaultFormat(const dnnl::memory::desc& desc) {
bool rslt = false;
if (desc.data.format_kind == dnnl_blocked) {
if (desc.data.format_desc.blocking.inner_nblks == 0) {
int i = 0;
for (i = 0; i < desc.data.ndims - 1; i++) {
if (desc.data.format_desc.blocking.strides[i] <
desc.data.format_desc.blocking.strides[i + 1]) {
break;
}
}
if (i == desc.data.ndims - 1) {
rslt = true;
}
}
}
return rslt;
}
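// Builds a descriptor with the same dims and data type as `desc` but the given format tag.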
dnnl::memory::desc GetDesc(const dnnl::memory::desc& desc, const dnnl_format_tag_t& format) {
dnnl::memory::dims dims(desc.data.ndims);
for (size_t i = 0; i < dims.size(); i++)
dims[i] = desc.data.dims[i];
dnnl::memory::format_tag cpp_format = static_cast<dnnl::memory::format_tag>(format);
dnnl::memory::data_type cpp_type = static_cast<dnnl::memory::data_type>(desc.data.data_type);
return dnnl::memory::desc(dims, cpp_type, cpp_format);
}
// Reorder the DNNL src memory into the dst memory's format and data type.
void ReorderTo(const dnnl::memory* src, const dnnl::memory* dst) {
dnnl::stream s(CpuEngine::Get()->get_engine());
auto new_src = *src;
auto new_dst = *dst;
dnnl::reorder(new_src, new_dst).execute(s, new_src, new_dst);
}
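// Falls back to the non-DNNL compute function `fn`: converts DNNL-formatted (and bf16)
// inputs and outputs to default-layout buffers, runs `fn` on plain TBlobs, and copies
// the results back into the original output arrays where needed.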
template <typename Compute, typename AttrState>
void FallBackCompute(Compute fn,
const AttrState& attrs_states,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
std::vector<TBlob> in_blobs(inputs.size());
std::vector<NDArray> in_bufs;
std::vector<OpReqType> new_req = req;
for (size_t i = 0; i < in_blobs.size(); i++) {
// If the input data isn't stored in the default format, we shouldn't
// call data() directly, which will change the layout of the NDArray.
// Instead, we should save the converted data in another NDArray.
// TODO(zhengda) we should use temp space to save the converted data.
if (inputs[i].IsDefaultData() && inputs[i].dtype() != mshadow::kBfloat16) {
in_blobs[i] = inputs[i].data();
} else {
if (in_bufs.empty())
in_bufs.reserve(inputs.size());
if (inputs[i].dtype() != mshadow::kBfloat16) {
in_bufs.push_back(inputs[i].Reorder2Default());
} else {
in_bufs.push_back(inputs[i].Reorder2DefaultFloatFormat());
}
in_blobs[i] = in_bufs.back().data();
}
}
DNNLStream::Get()->Submit();
std::vector<TBlob> out_blobs(outputs.size());
std::vector<NDArray> temp_src, temp_dst;
std::vector<NDArray> temp_bf16_src, temp_bf16_dst;
for (size_t i = 0; i < out_blobs.size(); i++) {
NDArray output = outputs[i];
// for bf16 outputs, first convert them to f32
if (outputs[i].dtype() == mshadow::kBfloat16) {
NDArray temp = outputs[i].Reorder2DefaultFloatFormat();
temp_bf16_src.emplace_back(temp);
temp_bf16_dst.emplace_back(outputs[i]);
output = temp;
if (req[i] == kWriteInplace) {
new_req[i] = kWriteTo;
}
} else {
// ensure output does not use dnnl mem.
// for inplace, we already converted & copied input above.
if ((req[i] == kWriteTo) || (req[i] == kWriteInplace)) {
const_cast<NDArray&>(output).InvalidateDNNLData();
if (req[i] == kWriteInplace) {
new_req[i] = kWriteTo;
}
} else if (req[i] == kAddTo && output.IsDNNLData()) {
NDArray temp = outputs[i].Reorder2Default();
temp_src.emplace_back(temp);
temp_dst.emplace_back(outputs[i]);
output = temp;
}
}
CHECK(output.IsDefaultData());
out_blobs[i] = output.data();
}
fn(attrs_states, ctx, in_blobs, new_req, out_blobs);
for (size_t i = 0, bf16_pos = 0; i < out_blobs.size(); i++) {
if (outputs[i].dtype() == mshadow::kBfloat16) {
auto src_mem = temp_bf16_src[bf16_pos].GetDNNLData();
auto dst_mem = temp_bf16_dst[bf16_pos].GetDNNLData();
bf16_pos++;
ReorderTo(src_mem, dst_mem);
} else if (req[i] == kAddTo && outputs[i].IsDNNLData()) {
mxnet::common::CastNonDefaultStorage(temp_src, temp_dst, ctx, false);
}
}
}
template <typename DType>
void print_diff(const mxnet::NDArray& arr1, const mxnet::NDArray& arr2) {
DType* data1 = reinterpret_cast<DType*>(arr1.data().dptr_);
DType* data2 = reinterpret_cast<DType*>(arr2.data().dptr_);
for (size_t i = 0; i < arr1.shape().Size(); i++)
std::cout << data1[i] - data2[i] << ", ";
std::cout << std::endl;
}
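// Compares two arrays element-wise with the given relative and absolute tolerances.
// DNNL-formatted data is first copied into default-layout buffers.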
template <typename DType>
static bool SimilarArray(const mxnet::NDArray& arr1,
const mxnet::NDArray& arr2,
DType rtol,
DType atol) {
if (arr1.shape().Size() != arr2.shape().Size())
return false;
// This function should be used outside a DNNL operator.
// There shouldn't be any pending operations in the stream.
CHECK(!DNNLStream::Get()->HasOps());
// We need to reorder data in the arrays to the default layout.
// But we shouldn't reorder data in the original array.
NDArray buf1, buf2;
if (arr1.IsDNNLData()) {
buf1 = NDArray(arr1.shape(), arr1.ctx(), false, arr1.dtype());
auto mem = arr1.GetDNNLData();
buf1.CopyFrom(*mem);
}
if (arr2.IsDNNLData()) {
buf2 = NDArray(arr2.shape(), arr2.ctx(), false, arr2.dtype());
auto mem = arr2.GetDNNLData();
buf2.CopyFrom(*mem);
}
DNNLStream::Get()->Submit();
DType* data1 =
reinterpret_cast<DType*>(arr1.IsDNNLData() ? buf1.data().dptr_ : arr1.data().dptr_);
DType* data2 =
reinterpret_cast<DType*>(arr2.IsDNNLData() ? buf2.data().dptr_ : arr2.data().dptr_);
std::atomic<bool> success(true);
#pragma omp parallel for
#ifdef _MSC_VER
for (int64_t i = 0; i < arr1.shape().Size(); i++)
#else
for (size_t i = 0; i < arr1.shape().Size(); i++)
#endif
{
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wabsolute-value"
if (std::abs(data1[i] - data2[i]) > atol + rtol * std::abs(data2[i]))
success.store(false);
#pragma clang diagnostic pop
}
return success.load();
}
template void FallBackCompute(void (*)(nnvm::NodeAttrs const&,
OpContext const&,
std::vector<TBlob, std::allocator<TBlob> > const&,
std::vector<OpReqType, std::allocator<OpReqType> > const&,
std::vector<TBlob, std::allocator<TBlob> > const&),
nnvm::NodeAttrs const&,
OpContext const&,
std::vector<NDArray, std::allocator<NDArray> > const&,
std::vector<OpReqType, std::allocator<OpReqType> > const&,
std::vector<NDArray, std::allocator<NDArray> > const&);
template void FallBackCompute(void (*)(OpStatePtr const&,
OpContext const&,
std::vector<TBlob, std::allocator<TBlob> > const&,
std::vector<OpReqType, std::allocator<OpReqType> > const&,
std::vector<TBlob, std::allocator<TBlob> > const&),
OpStatePtr const&,
OpContext const&,
std::vector<NDArray, std::allocator<NDArray> > const&,
std::vector<OpReqType, std::allocator<OpReqType> > const&,
std::vector<NDArray, std::allocator<NDArray> > const&);
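// Copies the operator's inputs (and, for backward passes, its outputs) into
// default-layout arrays so the reference FCompute path can be run for comparison.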
void OpCheck::Init(const std::vector<mxnet::NDArray>& inputs_,
const std::vector<mxnet::NDArray>& outputs_) {
auto ctx = inputs_[0].ctx();
CHECK(!DNNLStream::Get()->HasOps());
for (size_t i = 0; i < inputs_.size(); i++) {
NDArray data = inputs_[i];
inputs.emplace_back(data.shape(), ctx, false, data.dtype());
if (data.IsDNNLData() && data.IsView())
data = data.Reorder2Default();
auto mem = data.GetDNNLData();
inputs[i].CopyFrom(*mem);
}
for (size_t i = 0; i < outputs_.size(); i++) {
outputs.emplace_back(outputs_[i].shape(), ctx, false, outputs_[i].dtype());
if (backward) {
auto mem = outputs_[i].GetDNNLData();
outputs[i].CopyFrom(*mem);
}
}
DNNLStream::Get()->Submit();
}
void OpCheck::Run(mxnet::FCompute fn,
const nnvm::NodeAttrs& attrs,
const mxnet::OpContext& ctx,
const std::vector<mxnet::NDArray>& inputs_,
const std::vector<mxnet::OpReqType>& req,
const std::vector<mxnet::NDArray>& outputs_) {
static auto& is_excluded = Op::GetAttr<bool>("TExcludeDNNLDebug");
if (is_excluded.get(attrs.op, false)) {
LOG(WARNING) << attrs.op->name << " not checked. TExcludeDNNLDebug flag present";
return;
}
std::vector<mxnet::TBlob> in_blobs(inputs.size());
for (size_t i = 0; i < in_blobs.size(); i++)
in_blobs[i] = inputs[i].data();
std::vector<mxnet::TBlob> out_blobs(outputs.size());
for (size_t i = 0; i < out_blobs.size(); i++)
out_blobs[i] = outputs[i].data();
fn(attrs, ctx, in_blobs, req, out_blobs);
if (dmlc::GetEnv("MXNET_ONEDNN_DEBUG", false))
LOG(INFO) << "test " << attrs.op->name;
size_t num = std::min(outputs.size(), outputs_.size());
num = std::min(num_checks, num);
for (size_t i = 0; i < num; i++) {
// We don't need to compare outputs that the operator does not write.
if (req[i] == kNullOp)
continue;
MSHADOW_TYPE_SWITCH(outputs[i].dtype(), DType, {
bool similar = SimilarArray<DType>(
outputs[i], outputs_[i], static_cast<DType>(1e-2), static_cast<DType>(1e-2));
if (!similar) {
LOG(ERROR) << attrs.op->name << " fails";
}
CHECK(similar);
});
}
}
void OpCheck::CopyResult(const std::vector<mxnet::NDArray>& outputs_,
const std::vector<size_t>& indice) {
CHECK(!DNNLStream::Get()->HasOps());
auto non_const_outputs_ = const_cast<std::vector<mxnet::NDArray>&>(outputs_);
for (auto i = indice.begin(); i != indice.end(); ++i) {
auto mem = outputs[*i].GetDNNLData();
non_const_outputs_[*i].CopyFrom(*mem);
}
DNNLStream::Get()->Submit();
}
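// Storage-type inference helper: dispatches to kFComputeEx when oneDNN is enabled and
// supported for this operator on CPU, and otherwise falls back to the FCompute path.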
bool DNNLStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
bool support_dnnl,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
for (int& v : *in_attrs)
if (v == -1)
v = kDefaultStorage;
DispatchMode wanted_mode;
#if MXNET_USE_ONEDNN == 1
if (dev_mask == mshadow::cpu::kDevMask && !DNNLEnvSet())
wanted_mode = DispatchMode::kFComputeFallback;
else if (dev_mask == mshadow::cpu::kDevMask && support_dnnl)
wanted_mode = DispatchMode::kFComputeEx;
else
#endif
wanted_mode = DispatchMode::kFCompute;
bool dispatched = false;
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
dispatched =
op::storage_type_assign(out_attrs, mxnet::kDefaultStorage, dispatch_mode, wanted_mode);
}
if (!dispatched) {
dispatched = op::dispatch_fallback(out_attrs, dispatch_mode);
}
return dispatched;
}
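// Returns a copy of `inputs` in which every view stored in a DNNL layout has been
// reordered to the default layout.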
inline static const std::vector<NDArray> GetDNNLInputArray(const std::vector<NDArray>& inputs) {
std::vector<NDArray> ret;
ret.reserve(inputs.size());
for (const auto& in : inputs) {
if (in.IsView() && in.IsDNNLData()) {
ret.push_back(in.Reorder2Default());
} else {
ret.push_back(in);
}
}
return ret;
}
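// Runs `fn` after reordering any input arrays that are views stored in a DNNL layout
// back to the default layout.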
void DNNLRun(mxnet::FComputeEx fn,
const nnvm::NodeAttrs& attrs,
const mxnet::OpContext& ctx,
const std::vector<mxnet::NDArray>& inputs,
const std::vector<mxnet::OpReqType>& req,
const std::vector<mxnet::NDArray>& outputs) {
if (CheckDNNLInputArrayIsView(inputs)) {
const auto dnnl_inputs = GetDNNLInputArray(inputs);
fn(attrs, ctx, dnnl_inputs, req, outputs);
} else {
fn(attrs, ctx, inputs, req, outputs);
}
}
void DNNLRun(FComputeExUnary fn,
const nnvm::NodeAttrs& attrs,
const mxnet::OpContext& ctx,
const mxnet::NDArray& input,
const mxnet::OpReqType& req,
const mxnet::NDArray& output) {
auto dnnl_input = input;
if (input.IsView() && input.IsDNNLData()) {
dnnl_input = input.Reorder2Default();
fn(attrs, ctx, dnnl_input, req, output);
} else {
fn(attrs, ctx, input, req, output);
}
}
} // namespace mxnet
#endif