/*******************************************************************************
* Copyright 2016-2017 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* \file dnnl_base-inl.h
* \brief
* \author young.jin.kim@intel.com
* ashok.emani@intel.com
* deepthi.karkada@intel.com
* louis.feng@intel.com
* adam.d.straw@intel.com
* zhengda1936@gmail.com
*
*******************************************************************************/
#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_BASE_INL_H_
#define MXNET_OPERATOR_NN_DNNL_DNNL_BASE_INL_H_
#if MXNET_USE_ONEDNN == 1
#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "dnnl.hpp"
#include "mxnet/graph_attr_types.h"
#include "mxnet/ndarray.h"
#include "mxnet/op_attr_types.h"
#include "mxnet/resource.h"
#define DNNL_REAL_TYPE_SWITCH(type, DType, ...) \
switch (type) { \
case mshadow::kFloat32: { \
typedef float DType; \
{ __VA_ARGS__ } \
} break; \
case mshadow::kBfloat16: { \
typedef mshadow::bfloat::bf16_t DType; \
{ __VA_ARGS__ } \
} break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}
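/*
* Illustrative usage sketch (not part of this header): dispatch on an MXNet
* real-type enum to get the matching C++ type. `dtype` and `data_handle` are
* hypothetical names.
*
*   DNNL_REAL_TYPE_SWITCH(dtype, DType, {
*     DType* ptr = static_cast<DType*>(data_handle);
*     // ... work on the typed pointer ...
*   });
*/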
namespace mxnet {
// ===== CpuEngine =======================================
// cpu_engine singleton
class CpuEngine {
public:
static CpuEngine* Get() {
// It's thread-safe in C++11.
// Ensure the same dnnl engine is used across threads.
static CpuEngine myInstance;
return &myInstance;
}
CpuEngine(CpuEngine const&) = delete; // Copy construct
CpuEngine(CpuEngine&&) = delete; // Move construct
CpuEngine& operator=(CpuEngine const&) = delete; // Copy assign
CpuEngine& operator=(CpuEngine&&) = delete; // Move assign
dnnl::engine& get_engine() {
return _cpu_engine;
}
protected:
CpuEngine() : _cpu_engine(dnnl::engine::kind::cpu, 0) {}
~CpuEngine() {}
private:
dnnl::engine _cpu_engine;
};
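/*
* Illustrative usage sketch (not part of this header): the singleton provides
* the process-wide dnnl engine used when creating dnnl memory and primitives.
* `md` and `buffer` are hypothetical names.
*
*   dnnl::engine& eng = CpuEngine::Get()->get_engine();
*   dnnl::memory mem(md, eng, buffer);
*/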
// type enumerator
template <typename T>
struct data_type_enum {};
template <>
struct data_type_enum<float> {
enum { type = static_cast<unsigned int>(dnnl::memory::data_type::f32) };
};
template <>
struct data_type_enum<mshadow::bfloat::bf16_t> {
enum { type = static_cast<unsigned int>(dnnl::memory::data_type::bf16) };
};
template <>
struct data_type_enum<int32_t> {
enum { type = static_cast<unsigned int>(dnnl::memory::data_type::s32) };
};
template <>
struct data_type_enum<int8_t> {
enum { type = static_cast<unsigned int>(dnnl::memory::data_type::s8) };
};
template <>
struct data_type_enum<uint8_t> {
enum { type = static_cast<unsigned int>(dnnl::memory::data_type::u8) };
};
static inline bool SupportDNNLArray(int dtype, const mxnet::TShape& shape) {
int ndim = shape.ndim();
bool support = ndim == 1 || ndim == 2 || ndim == 4;
support = support &&
(dtype == mshadow::kFloat32 || dtype == mshadow::kInt32 || dtype == mshadow::kInt8 ||
dtype == mshadow::kUint8 || dtype == mshadow::kBfloat16);
return support;
}
static inline bool SupportStorageDNNL(int stype) {
return stype == kDefaultStorage;
}
static inline bool SupportDNNL(int dtype, const mxnet::TShape& shape) {
int ndim = shape.ndim();
if (ndim == 0 || shape.Size() == 0) {
// DNNL currently does not support 0-dim or 0-size tensors
return false;
}
return (dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16) &&
(ndim == 1 || ndim == 2 || ndim == 4);
}
static inline bool IsDNNLType(int dtype) {
return dtype == mshadow::kFloat32 || dtype == mshadow::kInt8 || dtype == mshadow::kUint8 ||
dtype == mshadow::kBfloat16;
}
static inline bool SupportDNNL(const NDArray& input) {
return SupportDNNL(input.dtype(), input.shape()) && SupportStorageDNNL(input.storage_type());
}
static inline bool DNNLEnvSet() {
static bool is_dnnl_enabled = dmlc::GetEnv("MXNET_ONEDNN_ENABLED", true);
return is_dnnl_enabled;
}
static inline int GetDNNLCacheSize() {
static int dnnl_cache_size = dmlc::GetEnv("MXNET_ONEDNN_CACHE_NUM", -1);
return dnnl_cache_size;
}
// TODO(alex): (MXNET-1075) Will remove env variable and calculate cache size during runtime
template <typename S, typename I, typename H>
static typename std::unordered_map<S, I, H>::iterator AddToCache(std::unordered_map<S, I, H>* cache,
const S& key,
const I& item) {
int dnnl_cache_size = GetDNNLCacheSize();
if (dnnl_cache_size != -1 && static_cast<int>(cache->size()) > dnnl_cache_size)
cache->erase(cache->begin());
auto ins_return = cache->insert(std::pair<S, I>(key, item));
CHECK(ins_return.second);
return ins_return.first;
}
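/*
* Illustrative usage sketch (not part of this header): operator implementations
* typically keep a thread-local map from a key describing the computation to a
* cached primitive, and insert misses with AddToCache so the cache is bounded
* by MXNET_ONEDNN_CACHE_NUM. MyKey, MyKeyHash, MyPrimitive and
* CreateMyPrimitive are hypothetical names.
*
*   static thread_local std::unordered_map<MyKey, MyPrimitive, MyKeyHash> cache;
*   auto it = cache.find(key);
*   if (it == cache.end())
*     it = AddToCache(&cache, key, CreateMyPrimitive(key));
*   return it->second;
*/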
/*
* Align an address to the given alignment.
*/
void* AlignMem(void* mem, size_t size, size_t alignment, size_t* space);
namespace op {
struct ActivationParam;
struct LeakyReLUParam;
struct ConvolutionParam;
struct DeconvolutionParam;
struct SoftmaxParam;
struct MaskedSoftmaxParam;
struct SoftmaxOutputParam;
struct ReshapeParam;
struct LayerNormParam;
bool SupportDNNLAct(const ActivationParam& param);
bool SupportDNNLAct(const ActivationParam& param, const NDArray& input);
bool SupportDNNLLeakyRelu(const LeakyReLUParam& param);
bool SupportDNNLLeakyRelu(const LeakyReLUParam& param, const NDArray& input);
bool SupportQuantizedDNNLAct(const ActivationParam& param);
bool SupportDNNLConv(const ConvolutionParam& params, const NDArray& input);
bool SupportDNNLDeconv(const DeconvolutionParam& params, const NDArray& input);
bool SupportDNNLSoftmax(const SoftmaxParam& param, const NDArray& input, const NDArray& output);
bool SupportDNNLLogSoftmax(const SoftmaxParam& param, const NDArray& input, const NDArray& output);
bool SupportDNNLMaskedSoftmax(const MaskedSoftmaxParam& param,
const std::vector<NDArray>& input,
const NDArray& output);
bool SupportDNNLSoftmaxOutput(const SoftmaxOutputParam& param);
bool SupportDNNLTranspose(const NDArray& data);
bool SupportDNNLBatchDot(const std::vector<NDArray>& inputs, const NDArray& output);
bool SupportDNNLLayerNorm(const LayerNormParam& param, const std::vector<NDArray>& inputs);
bool SupportDNNLReshape(const NDArray& input, const NDArray& output);
bool SupportDNNLSplit(const NDArray& input);
bool SupportDNNLStack(const std::vector<NDArray>& inputs);
bool SupportDNNLBinary(const std::vector<NDArray>& inputs);
bool SupportDNNLTanh(const NDArray& input, const NDArray& output);
} // namespace op
static int GetTypeSize(int dtype) {
int size = -1;
MSHADOW_TYPE_SWITCH(dtype, DType, { size = sizeof(DType); });
return size;
}
static inline size_t GetArraySize(const NDArray& arr) {
if (arr.IsDNNLData()) {
return arr.GetDNNLData()->get_desc().get_size();
}
return arr.shape().Size() * GetTypeSize(arr.dtype());
}
static inline dnnl::memory::data_type get_dnnl_type(int dtype) {
switch (dtype) {
case mshadow::kFloat32:
return dnnl::memory::data_type::f32;
case mshadow::kBfloat16:
return dnnl::memory::data_type::bf16;
case mshadow::kInt32:
return dnnl::memory::data_type::s32;
case mshadow::kInt8:
return dnnl::memory::data_type::s8;
case mshadow::kUint8:
return dnnl::memory::data_type::u8;
default:
LOG(FATAL) << "unknown type for oneDNN :" << static_cast<int>(dtype);
return dnnl::memory::data_type::undef;
}
}
template <typename T>
static inline dnnl::memory::data_type get_dnnl_type() {
return static_cast<dnnl::memory::data_type>(data_type_enum<T>::type);
}
static inline dnnl_data_type_t get_dnnl_type_t(int dtype) {
return static_cast<dnnl_data_type_t>(get_dnnl_type(dtype));
}
template <typename T>
static inline dnnl_data_type_t get_dnnl_type_t() {
return static_cast<dnnl_data_type_t>(data_type_enum<T>::type);
}
static inline int get_mxnet_type(dnnl_data_type_t dtype) {
auto dnnl_dtype = static_cast<dnnl::memory::data_type>(dtype);
switch (dnnl_dtype) {
case dnnl::memory::data_type::f32:
return mshadow::kFloat32;
case dnnl::memory::data_type::bf16:
return mshadow::kBfloat16;
case dnnl::memory::data_type::s32:
return mshadow::kInt32;
case dnnl::memory::data_type::s8:
return mshadow::kInt8;
case dnnl::memory::data_type::u8:
return mshadow::kUint8;
default:
LOG(FATAL) << "unknown oneDNN data type";
return mshadow::kFloat32;
}
}
static inline size_t GetMemDescSize(const dnnl::memory::desc& md) {
if (md.data.ndims == 0)
return 0;
size_t ret = 1;
for (int i = 0; i < md.data.ndims; i++) {
ret *= md.data.dims[i];
}
ret *= mshadow::mshadow_sizeof(get_mxnet_type(md.data.data_type));
return ret;
}
inline static dnnl::memory::desc GetMemDesc(const NDArray& arr, int dtype = -1) {
int ndim = arr.shape().ndim();
dnnl::memory::dims dims(ndim);
dtype = (dtype == -1) ? arr.dtype() : dtype;
for (size_t i = 0; i < dims.size(); i++)
dims[i] = arr.shape()[i];
return dnnl::memory::desc{dims, get_dnnl_type(dtype), dnnl::memory::format_tag::any};
}
inline static bool ChooseBRGEMMImpl(const dnnl::memory::dims& weight_dims, size_t batch_size) {
// Conditions based on measurement results done on CLX8280
// https://github.com/apache/incubator-mxnet/pull/20533
return weight_dims[0] >= 1024 && weight_dims[1] >= 1024 && batch_size >= 16384 &&
weight_dims[0] % 64 == 0 && weight_dims[1] % 64 == 0;
}
inline static dnnl::memory::desc GetFCWeightDesc(const NDArray& arr,
size_t batch_size,
int dtype = -1) {
int ndim = arr.shape().ndim();
dnnl::memory::dims dims(ndim);
dtype = (dtype == -1) ? arr.dtype() : dtype;
for (size_t i = 0; i < dims.size(); i++)
dims[i] = arr.shape()[i];
auto format = dnnl::memory::format_tag::any;
// for batch 256 alexnet benchmark test
const bool force_fc_ab_format = dmlc::GetEnv("MXNET_ONEDNN_FORCE_FC_AB_FORMAT", false);
if (dims.size() == 2) {
if (force_fc_ab_format || !ChooseBRGEMMImpl(dims, batch_size)) {
format = dnnl::memory::format_tag::ab;
}
}
return dnnl::memory::desc{dims, get_dnnl_type(dtype), format};
}
inline static dnnl::memory::desc GetWeightDesc(const NDArray& arr,
int num_groups,
bool quantized = false) {
int dtype = quantized ? mshadow::kInt8 : arr.dtype();
if (num_groups == 1) {
return GetMemDesc(arr, dtype);
} else {
const auto ndim = arr.shape().ndim();
CHECK((ndim == 3) || (ndim == 4) || (ndim == 5))
<< "oneDNN weight currently supports 3d or 4d or 5d layout";
auto tz = dnnl::memory::dims{0};
int N = 0, C = 1, H = 2, W = 3;
int D = -1;
if (ndim == 5) {
D = 2;
H = 3;
W = 4;
}
switch (ndim) {
case 3:
tz = dnnl::memory::dims{
num_groups, arr.shape()[N] / num_groups, arr.shape()[C], arr.shape()[H]};
break;
case 4:
tz = dnnl::memory::dims{num_groups,
arr.shape()[N] / num_groups,
arr.shape()[C],
arr.shape()[H],
arr.shape()[W]};
break;
case 5:
tz = dnnl::memory::dims{num_groups,
arr.shape()[N] / num_groups,
arr.shape()[C],
arr.shape()[D],
arr.shape()[H],
arr.shape()[W]};
}
return dnnl::memory::desc{tz, get_dnnl_type(dtype), dnnl::memory::format_tag::any};
}
}
inline static bool CheckDNNLInputArrayIsView(const std::vector<NDArray>& inputs) {
for (const auto& in : inputs) {
if (in.IsView() && in.IsDNNLData()) {
return true;
}
}
return false;
}
typedef std::shared_ptr<dnnl::memory> dnnl_mem_ptr;
typedef std::shared_ptr<const dnnl::memory> dnnl_mem_const_ptr;
/*
* This class manages the temporary memory provided by MXNet for operators.
* The temp memory is mainly used to keep reordered data. An operator may
* need multiple pieces of memory for reordering, but MXNet can only provide
* a single piece of memory. This class helps split the temporary memory
* from MXNet so it can hold the reordered data.
* The amount of temporary memory used in an operator depends on the layouts of
* the input arrays and on the operator itself. It's difficult to calculate
* manually, so the class also estimates the amount of memory automatically.
*/
class TmpMemMgr {
// This points to the memory buffer where we can allocate temp memory.
char* curr_mem;
// The total size of the temp memory.
size_t mem_size;
// This contains the current available memory size.
size_t curr_size;
// The estimated temp memory size required by an operator.
size_t est_size;
const size_t alignment = kDNNLAlign;
public:
static TmpMemMgr* Get() {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local TmpMemMgr mgr;
#else
static MX_THREAD_LOCAL TmpMemMgr mgr;
#endif
return &mgr;
}
TmpMemMgr() {
Reset();
est_size = 0;
mem_size = 0;
}
void Reset() {
curr_mem = nullptr;
curr_size = 0;
// We don't reset est_size and mem_size because est_size contains the
// estimated temp memory size from the last run and mem_size contains the
// memory size allocated in the last run.
}
void Init(const Resource& r) {
// If we estimated last time that we need more memory, we should use the
// larger memory size.
mem_size = std::max(mem_size, est_size);
if (mem_size > 0) {
// Let's allocate some extra memory. If part of it is never used,
// the OS won't physically allocate pages for it anyway.
this->curr_size = mem_size * 2;
this->curr_mem = static_cast<char*>(r.get_host_space_internal(this->curr_size));
}
// reset est_size, so we can start to estimate the temp memory size.
this->est_size = 0;
}
dnnl::memory* Alloc(const dnnl::memory::desc& md);
};
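/*
* Illustrative usage sketch (not part of this header): an operator typically
* initializes the manager from the temp-space resource in its OpContext and
* then allocates reorder buffers from it. `ctx` and `md` are hypothetical
* names.
*
*   TmpMemMgr::Get()->Init(ctx.requested[0]);
*   dnnl::memory* tmp = TmpMemMgr::Get()->Alloc(md);
*/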
typedef std::unordered_map<int, dnnl::memory> dnnl_args_map_t;
class DNNLStream {
std::vector<std::pair<dnnl::primitive, dnnl_args_map_t> > net_prim_args;
// Here we hold all memory related to the operators in the stream.
std::vector<std::shared_ptr<const dnnl::memory> > mem_holder;
dnnl::stream s;
public:
static DNNLStream* Get();
DNNLStream() : s(CpuEngine::Get()->get_engine()) {}
void RegisterPrimArgs(const dnnl::primitive& prim, const dnnl_args_map_t& args) {
net_prim_args.emplace_back(prim, args);
}
void RegisterMem(std::shared_ptr<const dnnl::memory> mem) {
mem_holder.push_back(mem);
}
bool HasOps() const {
return !net_prim_args.empty();
}
/*
* After submitting dnnl operations for execution, we need to
* clean up memory held by the stream. However, sometimes users
* might want to separate dnnl execution and memory cleanup.
*/
void Submit(bool cleanup = true) {
if (!net_prim_args.empty()) {
for (auto& v : net_prim_args) {
v.first.execute(s, v.second);
}
net_prim_args.clear();
}
if (cleanup)
Cleanup();
}
void Cleanup() {
mem_holder.clear();
TmpMemMgr::Get()->Reset();
}
};
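/*
* Illustrative usage sketch (not part of this header): primitives are
* registered together with their argument maps and then executed in one batch.
* `prim`, `src_mem` and `dst_mem` are hypothetical names.
*
*   dnnl_args_map_t args = {{DNNL_ARG_SRC, *src_mem}, {DNNL_ARG_DST, *dst_mem}};
*   DNNLStream::Get()->RegisterPrimArgs(prim, args);
*   DNNLStream::Get()->Submit();
*/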
enum OutDataOp {
Noop,
CopyBack,
AddBack,
};
typedef std::pair<OutDataOp, dnnl::memory*> dnnl_output_t;
void DNNLMemoryCopy(const dnnl::memory& mem, const dnnl::memory* this_mem);
/*
* Here we want to get DNNL memory whose desc is exactly the same as
* the given one. operator== can't guarantee that: it can return true even if
* the formats are different, so we need to double check the format.
*/
static inline dnnl::memory* GetDNNLExact(const dnnl::memory* mem, const dnnl::memory::desc& desc) {
dnnl::memory::desc src_desc = mem->get_desc();
if (desc == src_desc) {
return const_cast<dnnl::memory*>(mem);
} else {
std::shared_ptr<dnnl::memory> ret(
new dnnl::memory(desc, CpuEngine::Get()->get_engine(), mem->get_data_handle()));
DNNLStream::Get()->RegisterMem(ret);
return ret.get();
}
}
/*
* These two functions try to create DNNL memory in an NDArray based on `req'.
* The difference is that the first function can create DNNL memory with
* special layouts in an NDArray, while the second one can only create DNNL
* memory with default layouts.
* Also, an optional in_arr parameter can be passed to the first function with
* the kWriteInPlace req to validate whether dnnl supports write-in-place;
* otherwise new memory will be written to and copied back onto out_arr.
* If these two functions are used, we have to call CommitOutput to write
* the output back to the output NDArray.
*/
dnnl_output_t CreateDNNLMem(const NDArray& out_arr,
const dnnl::memory::desc& desc,
OpReqType req,
const NDArray* in_arr = nullptr);
dnnl_output_t CreateDNNLWeightGrad(const NDArray& out_arr,
const dnnl::memory::desc& desc,
OpReqType req);
/* This function has to be used with one of the functions above. */
void CommitOutput(const NDArray& arr, const dnnl_output_t& res);
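/*
* Illustrative usage sketch (not part of this header): create the output memory
* according to `req`, let a primitive write into it, and commit the result back
* to the output NDArray. `out_arr`, `out_md` and `req` are hypothetical names.
*
*   dnnl_output_t out_mem = CreateDNNLMem(out_arr, out_md, req);
*   // ... register a primitive that writes to out_mem.second ...
*   CommitOutput(out_arr, out_mem);
*   DNNLStream::Get()->Submit();
*/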
const dnnl::memory* GetWeights(const NDArray& arr, int num_groups);
const dnnl::memory* GetWeights(const NDArray& arr,
const dnnl::memory::desc& target_md,
int num_groups);
bool IsDefaultFormat(const dnnl::memory::desc& desc);
bool IsDNNL(const dnnl::memory::desc& desc);
dnnl_format_tag_t GetDefaultFormat(const dnnl::memory::desc& md);
dnnl_format_tag_t GetDefaultFormat(int num_dims);
dnnl::memory::desc GetDesc(const dnnl::memory::desc& md, const dnnl_format_tag_t& format);
inline bool same_shape(const mxnet::TShape& shape, const dnnl_dims_t dims, int ndims) {
if (shape.ndim() != ndims)
return false;
for (int i = 0; i < ndims; i++)
if (shape[i] != dims[i])
return false;
return true;
}
inline bool same_shape(const dnnl::memory::desc& desc1, const dnnl::memory::desc& desc2) {
if (desc1.data.ndims != desc2.data.ndims)
return false;
for (int i = 0; i < desc1.data.ndims; i++)
if (desc1.data.dims[i] != desc2.data.dims[i])
return false;
return true;
}
inline bool same_shape(const mxnet::TShape& shape, int dtype, const dnnl::memory::desc& desc) {
return same_shape(shape, desc.data.dims, desc.data.ndims) &&
get_dnnl_type(dtype) == desc.data.data_type;
}
/*
* There is a large overhead in getting dnnl::memory::desc from
* dnnl::memory. This class caches the metadata of dnnl memory
* to provide a much more lightweight way to access it.
*/
class DNNLMemory {
std::shared_ptr<dnnl::memory> mem;
dnnl::memory::desc desc;
size_t size; // The number of bytes.
public:
DNNLMemory(dnnl::memory::desc md, void* addr) : desc(md) {
mem.reset(new dnnl::memory(md, CpuEngine::Get()->get_engine(), addr));
size = desc.get_size();
}
explicit DNNLMemory(std::shared_ptr<dnnl::memory> mem) : desc(mem->get_desc()) {
this->mem = mem;
size = desc.get_size();
}
void SetDataHandle(void* handle) {
mem->set_data_handle(handle);
}
void* GetDataHandle() const {
return mem->get_data_handle();
}
std::shared_ptr<dnnl::memory> GetMem() const {
return mem;
}
dnnl::memory* GetRaw() const {
return mem.get();
}
size_t GetSize() const {
return size;
}
dnnl::memory::desc GetDesc() const {
return mem->get_desc();
}
dnnl::memory::desc GetDesc(
dnnl_format_tag_t format,
dnnl::memory::data_type data_type = dnnl::memory::data_type::undef) const {
dnnl::memory::dims dims(desc.data.dims, desc.data.dims + desc.data.ndims);
dnnl::memory::data_type cpp_type =
(data_type == dnnl::memory::data_type::undef) ?
static_cast<dnnl::memory::data_type>(desc.data.data_type) :
data_type;
dnnl::memory::desc data_md(dims, cpp_type, static_cast<dnnl::memory::format_tag>(format));
return data_md;
}
dnnl_format_tag_t GetDefaultFormat() const {
return mxnet::GetDefaultFormat(desc);
}
bool IsDNNL() const {
return mxnet::IsDNNL(desc);
}
bool SameFormat(dnnl::memory::desc md) const {
return mem->get_desc() == md;
}
bool SameFormat(const mxnet::TShape& shape, int dtype) const {
return same_shape(shape, dtype, desc);
}
void ReorderTo(dnnl::memory* other) const {
dnnl::stream s(CpuEngine::Get()->get_engine());
dnnl::reorder(*mem, *other).execute(s, *mem, *other);
}
};
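/*
* Illustrative usage sketch (not part of this header): wrap an existing buffer
* so that its desc and size can be queried cheaply. `md` and `buffer` are
* hypothetical names.
*
*   DNNLMemory mem(md, buffer);
*   size_t bytes = mem.GetSize();
*   bool plain_layout = !mem.IsDNNL();
*/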
// Reorder dnnl memory from the src format/dtype to the dst format/dtype.
void ReorderTo(const dnnl::memory* src, const dnnl::memory* dst);
template <typename Compute, typename AttrState>
void FallBackCompute(Compute fn,
const AttrState& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);
/*
* This class is used to check the correctness of DNNL operators.
*/
class OpCheck {
std::vector<mxnet::NDArray> inputs;
std::vector<mxnet::NDArray> outputs;
bool backward;
size_t num_checks;
public:
OpCheck(bool backward, size_t num_checks) {
this->backward = backward;
this->num_checks = num_checks;
}
void Init(const std::vector<mxnet::NDArray>& inputs_,
const std::vector<mxnet::NDArray>& outputs_);
void Run(mxnet::FCompute fn,
const nnvm::NodeAttrs& attrs,
const mxnet::OpContext& ctx,
const std::vector<mxnet::NDArray>& inputs_,
const std::vector<mxnet::OpReqType>& req,
const std::vector<mxnet::NDArray>& outputs_);
void CopyResult(const std::vector<mxnet::NDArray>& outputs_, const std::vector<size_t>& indice);
};
bool DNNLStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
bool support_dnnl,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs);
#define DNNL_OPCHECK_INIT(backward, num_checks, inputs, outputs) \
static bool debug = dmlc::GetEnv("MXNET_ONEDNN_DEBUG", false); \
OpCheck check(backward, num_checks); \
if (debug) \
check.Init(inputs, outputs);
#define DNNL_OPCHECK_RUN(fn, attrs, ctx, inputs, req, outputs) \
if (debug) \
check.Run(fn, attrs, ctx, inputs, req, outputs);
#define DNNL_OPCHECK_COPY_RESULT(outputs, indice) \
if (debug) \
check.CopyResult(outputs, indice);
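/*
* Illustrative usage sketch (not part of this header): a dnnl operator can wrap
* its body with these macros so that, when MXNET_ONEDNN_DEBUG is set, its
* result is checked against a reference FCompute implementation.
* SomeOpForward and SomeOpRefCompute are hypothetical names.
*
*   void SomeOpForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
*                      const std::vector<NDArray>& inputs,
*                      const std::vector<OpReqType>& req,
*                      const std::vector<NDArray>& outputs) {
*     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
*     // ... run the dnnl primitives ...
*     DNNL_OPCHECK_RUN(SomeOpRefCompute, attrs, ctx, inputs, req, outputs);
*   }
*/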
struct DNNLPostEltwiseParam {
dnnl::algorithm alg = dnnl::algorithm::undef;
float scale = 1.f;
float alpha = 0.f;
float beta = 1.f;
};
void DNNLRun(mxnet::FComputeEx fn,
const nnvm::NodeAttrs& attrs,
const mxnet::OpContext& ctx,
const std::vector<mxnet::NDArray>& inputs_,
const std::vector<mxnet::OpReqType>& req,
const std::vector<mxnet::NDArray>& outputs_);
using FComputeExUnary = std::function<void(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const NDArray& input,
const OpReqType& req,
const NDArray& output)>;
void DNNLRun(FComputeExUnary fn,
const nnvm::NodeAttrs& attrs,
const mxnet::OpContext& ctx,
const mxnet::NDArray& inputs_,
const mxnet::OpReqType& req,
const mxnet::NDArray& outputs_);
} // namespace mxnet
#endif
#endif // MXNET_OPERATOR_NN_DNNL_DNNL_BASE_INL_H_