[operator] Integrate oneDNN matmul primitive to mxnet dot operator (#20911)

* [operator] Integrate oneDNN matmul primitive to mxnet dot operator

* Fix numpy dot test

* Address review comments
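For reviewers who have not used the oneDNN matmul API that DNNLDotFwd wraps, below is a minimal standalone sketch of the descriptor -> primitive_desc -> primitive -> execute flow. It is an illustration only, not part of the patch, and assumes the oneDNN 2.x C++ header "oneapi/dnnl/dnnl.hpp" is available:

    #include <vector>
    #include "oneapi/dnnl/dnnl.hpp"

    int main() {
      using tag = dnnl::memory::format_tag;
      using dt  = dnnl::memory::data_type;
      const int M = 2, K = 3, N = 4;

      dnnl::engine eng(dnnl::engine::kind::cpu, 0);
      dnnl::stream strm(eng);

      std::vector<float> a(M * K, 1.f), b(K * N, 1.f), c(M * N, 0.f);

      // Plain row-major 2-D descriptors for src (lhs), weights (rhs) and dst.
      dnnl::memory::desc a_md({M, K}, dt::f32, tag::ab);
      dnnl::memory::desc b_md({K, N}, dt::f32, tag::ab);
      dnnl::memory::desc c_md({M, N}, dt::f32, tag::ab);

      // oneDNN 2.x API: operation descriptor -> primitive descriptor -> primitive.
      dnnl::matmul::desc op_d(a_md, b_md, c_md);
      dnnl::matmul::primitive_desc pd(op_d, eng);
      dnnl::matmul prim(pd);

      prim.execute(strm,
                   {{DNNL_ARG_SRC, dnnl::memory(a_md, eng, a.data())},
                    {DNNL_ARG_WEIGHTS, dnnl::memory(b_md, eng, b.data())},
                    {DNNL_ARG_DST, dnnl::memory(c_md, eng, c.data())}});
      strm.wait();  // c now holds the MxN product (every entry equals K here).
      return 0;
    }

Inside MXNet the engine, stream and output memory come from CpuEngine, DNNLStream and CreateDNNLMem instead, and the primitive is cached per thread (see DNNLDotFwd::GetCached in dnnl_dot.cc below).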
diff --git a/src/operator/nn/dnnl/dnnl_base-inl.h b/src/operator/nn/dnnl/dnnl_base-inl.h
index 7c9a916..97aacc1 100644
--- a/src/operator/nn/dnnl/dnnl_base-inl.h
+++ b/src/operator/nn/dnnl/dnnl_base-inl.h
@@ -198,6 +198,7 @@
                               const NDArray& output);
 bool SupportDNNLSoftmaxOutput(const SoftmaxOutputParam& param);
 bool SupportDNNLTranspose(const NDArray& data);
+bool SupportDNNLDot(const std::vector<NDArray>& inputs, const NDArray& output);
 bool SupportDNNLBatchDot(const std::vector<NDArray>& inputs, const NDArray& output);
 bool SupportDNNLLayerNorm(const LayerNormParam& param, const std::vector<NDArray>& inputs);
 bool SupportDNNLReshape(const NDArray& input, const NDArray& output);
@@ -540,6 +541,7 @@
 
 dnnl_format_tag_t GetDefaultFormat(const dnnl::memory::desc& md);
 dnnl_format_tag_t GetDefaultFormat(int num_dims);
+dnnl_format_tag_t GetPermutedFormat(int num_dims);
 dnnl::memory::desc GetDesc(const dnnl::memory::desc& md, const dnnl_format_tag_t& format);
 
 inline bool same_shape(const mxnet::TShape& shape, const dnnl_dims_t dims, int ndims) {
diff --git a/src/operator/nn/dnnl/dnnl_base.cc b/src/operator/nn/dnnl/dnnl_base.cc
index 05fabd5..ec50da7 100644
--- a/src/operator/nn/dnnl/dnnl_base.cc
+++ b/src/operator/nn/dnnl/dnnl_base.cc
@@ -347,6 +347,38 @@
   }
 }
 
+dnnl_format_tag_t GetPermutedFormat(int num_dims) {
+  switch (num_dims) {
+    case 1:
+      return dnnl_a;
+    case 2:
+      return dnnl_ba;
+    case 3:
+      return dnnl_acb;
+    case 4:
+      return dnnl_abdc;
+    case 5:
+      return dnnl_abced;
+    case 6:
+      return dnnl_abcdfe;
+    case 7:
+      return dnnl_abcdegf;
+    case 8:
+      return dnnl_abcdefhg;
+    case 9:
+      return dnnl_abcdefgih;
+    case 10:
+      return dnnl_abcdefghji;
+    case 11:
+      return dnnl_abcdefghikj;
+    case 12:
+      return dnnl_abcdefghijlk;
+    default:
+      LOG(FATAL) << "Not implemented dimension (" << num_dims << ") for oneDNN";
+      return dnnl_format_tag_undef;
+  }
+}
+
 dnnl_format_tag_t GetDefaultFormat(const dnnl::memory::desc& desc) {
   return GetDefaultFormat(desc.data.ndims);
 }
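Note on GetPermutedFormat: it returns the format tag whose last two axes are swapped relative to the plain layout (e.g. abc -> acb for a 3-D tensor); dnnl_dot.cc uses it as the destination layout of the rhs reorder in the NumPy case. A minimal sketch of such a reorder, illustration only and assuming oneDNN 2.x:

    #include <vector>
    #include "oneapi/dnnl/dnnl.hpp"

    int main() {
      using tag = dnnl::memory::format_tag;
      using dt  = dnnl::memory::data_type;
      dnnl::engine eng(dnnl::engine::kind::cpu, 0);
      dnnl::stream strm(eng);

      std::vector<float> src_buf(2 * 3 * 4), dst_buf(2 * 3 * 4);
      // Same logical dims on both sides; only the physical order of the last two axes differs.
      dnnl::memory src({{2, 3, 4}, dt::f32, tag::abc}, eng, src_buf.data());
      dnnl::memory dst({{2, 3, 4}, dt::f32, tag::acb}, eng, dst_buf.data());

      dnnl::reorder(src, dst).execute(strm, src, dst);
      strm.wait();  // dst_buf now holds src with its last two axes transposed.
      return 0;
    }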
diff --git a/src/operator/nn/dnnl/dnnl_batch_dot-inl.h b/src/operator/nn/dnnl/dnnl_batch_dot-inl.h
index 4117b17..b48afd1 100644
--- a/src/operator/nn/dnnl/dnnl_batch_dot-inl.h
+++ b/src/operator/nn/dnnl/dnnl_batch_dot-inl.h
@@ -38,9 +38,6 @@
 namespace mxnet {
 namespace op {
 
-enum DotIn { lhs = 0, rhs, lhs_min, lhs_max, rhs_min, rhs_max };
-enum DotOut { out = 0, out_min, out_max };
-
 struct DNNLDotParam : public dmlc::Parameter<DNNLDotParam> {
   bool transpose_a;
   bool transpose_b;
diff --git a/src/operator/nn/dnnl/dnnl_dot-inl.h b/src/operator/nn/dnnl/dnnl_dot-inl.h
new file mode 100644
index 0000000..b375872
--- /dev/null
+++ b/src/operator/nn/dnnl/dnnl_dot-inl.h
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_dot-inl.h
+ */
+
+#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_DOT_INL_H_
+#define MXNET_OPERATOR_NN_DNNL_DNNL_DOT_INL_H_
+
+#if MXNET_USE_ONEDNN == 1
+
+#include <memory>
+#include <vector>
+
+#include "dnnl_base-inl.h"
+#include "operator/tensor/dot-inl.h"
+
+namespace mxnet {
+namespace op {
+
+using dot_fwd_t    = dnnl::matmul;
+using dot_fwd_pd_t = dnnl::matmul::primitive_desc;
+
+typedef ParamOpSign<DotParam> DotSignature;
+
+class DNNLDotFwd {
+ public:
+  static DNNLDotFwd& GetCached(const DotParam& param,
+                               const std::vector<NDArray>& inputs,
+                               const std::vector<NDArray>& outputs,
+                               const bool isNumpy);
+
+  DNNLDotFwd(const DotParam& param,
+             const std::vector<NDArray>& inputs,
+             const std::vector<NDArray>& outputs,
+             const bool isNumpy);
+
+  void Execute(const std::vector<NDArray>& inputs,
+               const std::vector<OpReqType>& req,
+               const std::vector<NDArray>& outputs,
+               const bool isNumpy);
+
+ private:
+  std::shared_ptr<dot_fwd_t> fwd;
+  std::shared_ptr<dot_fwd_pd_t> fwd_pd;
+};
+
+template <bool isNumpy>
+void DNNLDotForward(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<NDArray>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<NDArray>& outputs) {
+  DotParam param;
+  if (isNumpy) {
+    // The NumPy version of the dot operator does not support transpose flags.
+    param             = DotParam();
+    param.transpose_a = false;
+    param.transpose_b = false;
+  } else {
+    param = nnvm::get<DotParam>(attrs.parsed);
+  }
+  DNNLDotFwd& fwd = DNNLDotFwd::GetCached(param, inputs, outputs, isNumpy);
+  fwd.Execute(inputs, req, outputs, isNumpy);
+}
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_USE_ONEDNN == 1
+#endif  // MXNET_OPERATOR_NN_DNNL_DNNL_DOT_INL_H_
diff --git a/src/operator/nn/dnnl/dnnl_dot.cc b/src/operator/nn/dnnl/dnnl_dot.cc
new file mode 100644
index 0000000..3978646
--- /dev/null
+++ b/src/operator/nn/dnnl/dnnl_dot.cc
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_dot.cc
+ */
+
+#if MXNET_USE_ONEDNN == 1
+
+#include <memory>
+#include <unordered_map>
+
+#include "dnnl_dot-inl.h"
+
+namespace mxnet {
+namespace op {
+
+bool SupportDNNLDot(const std::vector<NDArray>& inputs, const NDArray& output) {
+#if MXNET_USE_BLAS_MKL == 1
+  return false;
+#endif
+  return inputs[DotIn::lhs].shape().Size() > 1 && inputs[DotIn::rhs].shape().Size() > 1 &&
+         inputs[DotIn::lhs].shape().ndim() > 0 && inputs[DotIn::rhs].shape().ndim() > 0 &&
+         output.shape().Size() != 0 && output.shape().ndim() > 0 && output.shape().ndim() <= 12 &&
+         (inputs[DotIn::lhs].dtype() == mshadow::kFloat32 ||
+          inputs[DotIn::lhs].dtype() == mshadow::kBfloat16);
+}
+
+DNNLDotFwd& DNNLDotFwd::GetCached(const DotParam& param,
+                                  const std::vector<NDArray>& inputs,
+                                  const std::vector<NDArray>& outputs,
+                                  const bool isNumpy) {
+  using dot_fwd_map = std::unordered_map<DotSignature, DNNLDotFwd, OpHash>;
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local dot_fwd_map fwds;
+#else
+  static MX_THREAD_LOCAL dot_fwd_map fwds;
+#endif
+
+  DotSignature key(param);
+  key.AddSign(inputs[DotIn::lhs]);
+  key.AddSign(inputs[DotIn::rhs]);
+  key.AddSign(outputs[DotOut::out]);
+  key.AddSign(static_cast<int>(isNumpy));
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    const DNNLDotFwd fwd(param, inputs, outputs, isNumpy);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
+auto GetMemoryDesc(const NDArray& tensor, int firstDim, int secondDim, const bool transpose) {
+  return dnnl::memory::desc(
+      dnnl::memory::dims{firstDim, secondDim},
+      get_dnnl_type(tensor.dtype()),
+      transpose ? dnnl::memory::format_tag::ba : dnnl::memory::format_tag::ab);
+}
+
+DNNLDotFwd::DNNLDotFwd(const DotParam& param,
+                       const std::vector<NDArray>& inputs,
+                       const std::vector<NDArray>& outputs,
+                       const bool isNumpy) {
+  auto shapeLhs = inputs[DotIn::lhs].shape(), shapeRhs = inputs[DotIn::rhs].shape();
+  auto ndimLhs = shapeLhs.ndim(), ndimRhs = shapeRhs.ndim();
+  dnnl::memory::desc lhs_md, rhs_md, out_md;
+  // NumPy expects an rhs tensor with more than 2 dimensions in Ax...xKxN layout, which differs
+  // from NDArray's KxAx...xN layout. For NumPy the shape in the rhs memory descriptor is still
+  // Kx(A*...*N), as for NDArray, but matching the actual data requires an additional reorder
+  // that permutes the last two axes: Ax...xKxN -> Ax...xNxK. For that data to match the
+  // Kx(A*...*N) shape, format_tag needs to be set to ba. The reorder described above is
+  // implemented in the Execute function.
+  const bool differentNumpy = isNumpy && ndimRhs > 2;
+  const int smallDimLhs     = param.transpose_a ? shapeLhs[0] : shapeLhs[ndimLhs - 1];
+  const int bigDimLhs       = shapeLhs.Size() / smallDimLhs;
+  const int smallDimRhs     = param.transpose_b ?
+                              shapeRhs[ndimRhs - 1] :
+                              (differentNumpy ? shapeRhs[ndimRhs - 2] : shapeRhs[0]);
+  const int bigDimRhs = shapeRhs.Size() / smallDimRhs;
+
+  lhs_md = GetMemoryDesc(inputs[DotIn::lhs], bigDimLhs, smallDimLhs, param.transpose_a);
+  rhs_md = GetMemoryDesc(
+      inputs[DotIn::rhs], smallDimRhs, bigDimRhs, param.transpose_b || differentNumpy);
+  out_md = dnnl::memory::desc({bigDimLhs, bigDimRhs},
+                              get_dnnl_type(outputs[DotOut::out].dtype()),
+                              dnnl::memory::format_tag::any);
+  dnnl::matmul::desc fwd_desc(lhs_md, rhs_md, out_md);
+  fwd_pd = std::make_shared<dot_fwd_pd_t>(fwd_desc, mxnet::CpuEngine::Get()->get_engine());
+  fwd    = std::make_shared<dot_fwd_t>(*fwd_pd);
+}
+
+void DNNLDotFwd::Execute(const std::vector<NDArray>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<NDArray>& outputs,
+                         const bool isNumpy) {
+  auto engine = mxnet::CpuEngine::Get()->get_engine();
+  auto lhs    = dnnl::memory(
+      fwd_pd->src_desc(), engine, reinterpret_cast<void*>(inputs[DotIn::lhs].data().dptr_));
+  auto rhs     = dnnl::memory(fwd_pd->weights_desc(), engine);
+  auto ndimRhs = inputs[DotIn::rhs].shape().ndim();
+  if (isNumpy && ndimRhs > 2) {
+    // The need for this reorder is explained in the DNNLDotFwd constructor.
+    auto tmp_rhs = inputs[DotIn::rhs].GetDNNLData();
+    dnnl::memory::desc rhs_md(
+        dnnl::memory::dims(inputs[DotIn::rhs].shape().begin(), inputs[DotIn::rhs].shape().end()),
+        get_dnnl_type(inputs[DotIn::rhs].dtype()),
+        static_cast<dnnl::memory::format_tag>(GetPermutedFormat(ndimRhs)));
+    dnnl::memory tmp_rhs_dst(rhs_md, engine, rhs.get_data_handle());
+    const auto rhs_reorder_pd = dnnl::reorder::primitive_desc(*tmp_rhs, tmp_rhs_dst);
+    DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(rhs_reorder_pd),
+                                        {{DNNL_ARG_FROM, *tmp_rhs}, {DNNL_ARG_TO, tmp_rhs_dst}});
+  } else {
+    rhs.set_data_handle(reinterpret_cast<void*>(inputs[DotIn::rhs].data().dptr_));
+  }
+  dnnl_output_t out_mem = CreateDNNLMem(
+      outputs[DotOut::out], fwd_pd->dst_desc(), req[DotOut::out], &inputs[DotIn::lhs]);
+
+  dnnl_args_map_t args = {
+      {DNNL_ARG_SRC, lhs},
+      {DNNL_ARG_WEIGHTS, rhs},
+      {DNNL_ARG_DST, *out_mem.second},
+  };
+
+  DNNLStream::Get()->RegisterPrimArgs(*fwd, args);
+  CommitOutput(outputs[DotOut::out], out_mem);
+  DNNLStream::Get()->Submit();
+}
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_USE_ONEDNN == 1
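A worked shape example of the NumPy rhs handling implemented above (the concrete numbers are illustrative, not taken from the patch):

    lhs shape (2, 3, 4), rhs shape (5, 4, 6); np.dot produces output shape (2, 3, 5, 6).
    smallDimLhs = 4 (= K), bigDimLhs = 2*3 = 6, so lhs_md describes a plain 6x4 matrix.
    differentNumpy holds (ndimRhs = 3 > 2), so smallDimRhs = shapeRhs[1] = 4 (= K) and
    bigDimRhs = 120/4 = 30, so rhs_md describes a 4x30 matrix with format_tag ba.
    Execute reorders rhs from abc (5, 4, 6) to acb, i.e. physically (5, 6, 4); read through
    the 4x30 ba descriptor that buffer is exactly K x (A*N) with matching strides.
    The matmul then computes a 6x30 result, which is the (2, 3, 5, 6) output in row-major order.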
diff --git a/src/operator/numpy/np_dot_forward.cc b/src/operator/numpy/np_dot_forward.cc
index 1c2da2d..8704860 100644
--- a/src/operator/numpy/np_dot_forward.cc
+++ b/src/operator/numpy/np_dot_forward.cc
@@ -22,7 +22,10 @@
  * \brief CPU Implementation of numpy-compatible dot
  */
 
-#include "./np_dot-inl.h"
+#include "np_dot-inl.h"
+#if MXNET_USE_ONEDNN == 1
+#include "operator/nn/dnnl/dnnl_dot-inl.h"
+#endif
 
 namespace mxnet {
 namespace op {
@@ -101,6 +104,32 @@
   return shape_is_known(*in_attrs) && shape_is_known(*out_attrs);
 }
 
+#if MXNET_USE_ONEDNN == 1
+static void NumpyDotComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                 const OpContext& ctx,
+                                 const std::vector<NDArray>& inputs,
+                                 const std::vector<OpReqType>& req,
+                                 const std::vector<NDArray>& outputs) {
+  if (SupportDNNLDot(inputs, outputs[0])) {
+    DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+    DNNLRun(DNNLDotForward<true>, attrs, ctx, inputs, req, outputs);
+    DNNL_OPCHECK_RUN(NumpyDotForward<cpu>, attrs, ctx, inputs, req, outputs);
+  } else {
+    FallBackCompute(NumpyDotForward<cpu>, attrs, ctx, inputs, req, outputs);
+  }
+}
+
+inline static bool NumpyDotStorageType(const nnvm::NodeAttrs& attrs,
+                                       const int dev_mask,
+                                       DispatchMode* dispatch_mode,
+                                       std::vector<int>* in_attrs,
+                                       std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
+}
+#endif
+
 NNVM_REGISTER_OP(_npi_dot)
     .describe(R"doc(Dot product of two arrays. Specifically,
 
@@ -134,6 +163,11 @@
                                 })
     .set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
     .set_attr<FCompute>("FCompute<cpu>", NumpyDotForward<cpu>)
+#if MXNET_USE_ONEDNN == 1
+    .set_attr<bool>("TIsDNNL", true)
+    .set_attr<FComputeEx>("FComputeEx<cpu>", NumpyDotComputeExCPU)
+    .set_attr<FInferStorageType>("FInferStorageType", NumpyDotStorageType)
+#endif
     .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_dot"})
     .add_argument("a", "NDArray-or-Symbol", "First input")
     .add_argument("b", "NDArray-or-Symbol", "Second input");
diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h
index 7cd9fa9..dac690f 100644
--- a/src/operator/tensor/dot-inl.h
+++ b/src/operator/tensor/dot-inl.h
@@ -44,6 +44,9 @@
 namespace mxnet {
 namespace op {
 
+enum DotIn { lhs = 0, rhs, lhs_min, lhs_max, rhs_min, rhs_max };
+enum DotOut { out = 0, out_min, out_max };
+
 struct DotParam : public dmlc::Parameter<DotParam> {
   bool transpose_a;
   bool transpose_b;
@@ -109,50 +112,57 @@
   using namespace mshadow::expr;
   const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
   Stream<xpu>* s        = ctx.get_stream<xpu>();
-  CHECK_EQ(outputs[0].type_flag_, inputs[0].type_flag_)
+  CHECK_EQ(outputs[DotOut::out].type_flag_, inputs[DotIn::lhs].type_flag_)
       << "Binary function only support input/output with the same type";
-  CHECK_EQ(outputs[0].type_flag_, inputs[1].type_flag_)
+  CHECK_EQ(outputs[DotOut::out].type_flag_, inputs[DotIn::rhs].type_flag_)
       << "Binary function only support input/output with the same type";
-  CHECK(outputs[0].type_flag_ == kFloat32 || outputs[0].type_flag_ == kFloat64 ||
-        (outputs[0].type_flag_ == kFloat16 && ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask))
+  CHECK(outputs[DotOut::out].type_flag_ == kFloat32 ||
+        outputs[DotOut::out].type_flag_ == kFloat64 ||
+        (outputs[DotOut::out].type_flag_ == kFloat16 &&
+         ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask))
       << "dot only supports float32/float64 for CPU, and float16/float32/float64 for GPU";
-  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+  MSHADOW_REAL_TYPE_SWITCH(outputs[DotOut::out].type_flag_, DType, {
     // VectorDot() with fp16 is not supported in mshadow. Dispatch to dot() instead.
-    if (inputs[0].ndim() == 1 && inputs[1].ndim() == 1 && inputs[0].type_flag_ != kFloat16) {
-      CHECK_NE(req[0], kAddTo) << "AddTo not yet supported";
-      Tensor<xpu, 1, DType> out = outputs[0].get<xpu, 1, DType>(s);
-      VectorDot(out, inputs[0].get<xpu, 1, DType>(s), inputs[1].get<xpu, 1, DType>(s));
+    if (inputs[DotIn::lhs].ndim() == 1 && inputs[DotIn::rhs].ndim() == 1 &&
+        inputs[DotIn::lhs].type_flag_ != kFloat16) {
+      CHECK_NE(req[DotOut::out], kAddTo) << "AddTo not yet supported";
+      Tensor<xpu, 1, DType> out = outputs[DotOut::out].get<xpu, 1, DType>(s);
+      VectorDot(
+          out, inputs[DotIn::lhs].get<xpu, 1, DType>(s), inputs[DotIn::rhs].get<xpu, 1, DType>(s));
     } else {
       index_t ma, na, mb, nb, m, n;
       if (param.transpose_a) {
-        ma = inputs[0].size(0);
-        na = inputs[0].Size() / ma;
+        ma = inputs[DotIn::lhs].size(0);
+        na = inputs[DotIn::lhs].Size() / ma;
         m  = na;
       } else {
-        na = inputs[0].size(inputs[0].ndim() - 1);
-        ma = inputs[0].Size() / na;
+        na = inputs[DotIn::lhs].size(inputs[DotIn::lhs].ndim() - 1);
+        ma = inputs[DotIn::lhs].Size() / na;
         m  = ma;
       }
       if (param.transpose_b) {
-        nb = inputs[1].size(inputs[1].ndim() - 1);
-        mb = inputs[1].Size() / nb;
+        nb = inputs[DotIn::rhs].size(inputs[DotIn::rhs].ndim() - 1);
+        mb = inputs[DotIn::rhs].Size() / nb;
         n  = mb;
       } else {
-        mb = inputs[1].size(0);
-        nb = inputs[1].Size() / mb;
+        mb = inputs[DotIn::rhs].size(0);
+        nb = inputs[DotIn::rhs].Size() / mb;
         n  = nb;
       }
-      Tensor<xpu, 2, DType> input0 = inputs[0].get_with_shape<xpu, 2, DType>(Shape2(ma, na), s);
-      Tensor<xpu, 2, DType> input1 = inputs[1].get_with_shape<xpu, 2, DType>(Shape2(mb, nb), s);
-      Tensor<xpu, 2, DType> out    = outputs[0].get_with_shape<xpu, 2, DType>(Shape2(m, n), s);
+      Tensor<xpu, 2, DType> input0 =
+          inputs[DotIn::lhs].get_with_shape<xpu, 2, DType>(Shape2(ma, na), s);
+      Tensor<xpu, 2, DType> input1 =
+          inputs[DotIn::rhs].get_with_shape<xpu, 2, DType>(Shape2(mb, nb), s);
+      Tensor<xpu, 2, DType> out =
+          outputs[DotOut::out].get_with_shape<xpu, 2, DType>(Shape2(m, n), s);
       if (param.transpose_a && param.transpose_b) {
-        ASSIGN_DISPATCH(out, req[0], dot(input0.T(), input1.T()));
+        ASSIGN_DISPATCH(out, req[DotOut::out], dot(input0.T(), input1.T()));
       } else if (!param.transpose_a && param.transpose_b) {
-        ASSIGN_DISPATCH(out, req[0], dot(input0, input1.T()));
+        ASSIGN_DISPATCH(out, req[DotOut::out], dot(input0, input1.T()));
       } else if (param.transpose_a && !param.transpose_b) {
-        ASSIGN_DISPATCH(out, req[0], dot(input0.T(), input1));
+        ASSIGN_DISPATCH(out, req[DotOut::out], dot(input0.T(), input1));
       } else {
-        ASSIGN_DISPATCH(out, req[0], dot(input0, input1));
+        ASSIGN_DISPATCH(out, req[DotOut::out], dot(input0, input1));
       }
     }
   });
@@ -257,8 +267,14 @@
     // dns, dns -> dns
     target_stype = hint_has_value ? target_stype : kDefaultStorage;
     if (target_stype == kDefaultStorage) {
-      dispatched =
-          storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
+      dispatched = storage_type_assign(&out_stype,
+                                       kDefaultStorage,
+                                       dispatch_mode,
+#if MXNET_USE_ONEDNN == 1
+                                       DispatchMode::kFComputeEx);
+#else
+                                       DispatchMode::kFCompute);
+#endif
     }
   }
   if (!dispatched && lhs_stype == kCSRStorage && only_lhs_transpose && rhs_rsp_or_dns) {
@@ -1369,6 +1385,14 @@
   return shape_is_known((*out_attrs)[0]);
 }
 
+#if MXNET_USE_ONEDNN == 1
+void DotForwardExDNNL(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& outputs);
+#endif
+
 template <typename xpu>
 void DotForwardEx(const nnvm::NodeAttrs& attrs,
                   const OpContext& ctx,
@@ -1378,36 +1402,78 @@
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
+#if MXNET_USE_ONEDNN == 1
+  if (common::ContainsOnlyStorage(inputs, kDefaultStorage) &&
+      common::ContainsOnlyStorage(outputs, kDefaultStorage)) {
+    if (std::is_same<xpu, cpu>::value) {
+      DotForwardExDNNL(attrs, ctx, inputs, req, outputs);
+    } else {
+      FallBackCompute(DotForward_<gpu>, attrs, ctx, inputs, req, outputs);
+    }
+    return;
+  }
+#endif
   const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
-  CHECK_EQ(inputs[0].shape().ndim(), 2) << "sparse dot only supports 2 dimensional lhs";
-  CHECK_EQ(inputs[1].shape().ndim(), 2) << "sparse dot only supports 2 dimensional rhs";
-  auto lhs_stype = inputs[0].storage_type();
-  auto rhs_stype = inputs[1].storage_type();
-  auto out_stype = outputs[0].storage_type();
+  CHECK_EQ(inputs[DotIn::lhs].shape().ndim(), 2) << "sparse dot only supports 2 dimensional lhs";
+  CHECK_EQ(inputs[DotIn::rhs].shape().ndim(), 2) << "sparse dot only supports 2 dimensional rhs";
+  auto lhs_stype = inputs[DotIn::lhs].storage_type();
+  auto rhs_stype = inputs[DotIn::rhs].storage_type();
+  auto out_stype = outputs[DotOut::out].storage_type();
   if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kDefaultStorage &&
       !param.transpose_b) {
-    TBlob ret = outputs[0].data();
-    DotCsrDnsDnsImpl(ctx, xpu(), inputs[0], inputs[1].data(), req[0], param.transpose_a, &ret);
+    TBlob ret = outputs[DotOut::out].data();
+    DotCsrDnsDnsImpl(ctx,
+                     xpu(),
+                     inputs[DotIn::lhs],
+                     inputs[DotIn::rhs].data(),
+                     req[DotOut::out],
+                     param.transpose_a,
+                     &ret);
   } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage &&
              out_stype == kDefaultStorage && !param.transpose_b) {
-    TBlob ret = outputs[0].data();
-    DotCsrRspDnsImpl(ctx, xpu(), inputs[0], inputs[1], req[0], param.transpose_a, &ret);
+    TBlob ret = outputs[DotOut::out].data();
+    DotCsrRspDnsImpl(ctx,
+                     xpu(),
+                     inputs[DotIn::lhs],
+                     inputs[DotIn::rhs],
+                     req[DotOut::out],
+                     param.transpose_a,
+                     &ret);
   } else if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage &&
              out_stype == kRowSparseStorage && !param.transpose_b) {
-    NDArray out = outputs[0];
-    DotCsrDnsRspImpl(ctx, xpu(), inputs[0], inputs[1].data(), req[0], param.transpose_a, &out);
+    NDArray out = outputs[DotOut::out];
+    DotCsrDnsRspImpl(ctx,
+                     xpu(),
+                     inputs[DotIn::lhs],
+                     inputs[DotIn::rhs].data(),
+                     req[DotOut::out],
+                     param.transpose_a,
+                     &out);
   } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage &&
              out_stype == kRowSparseStorage && !param.transpose_b) {
-    NDArray ret = outputs[0];
-    DotCsrRspRspImpl(ctx, xpu(), inputs[0], inputs[1], req[0], param.transpose_a, &ret);
+    NDArray ret = outputs[DotOut::out];
+    DotCsrRspRspImpl(ctx,
+                     xpu(),
+                     inputs[DotIn::lhs],
+                     inputs[DotIn::rhs],
+                     req[DotOut::out],
+                     param.transpose_a,
+                     &ret);
   } else if (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage && out_stype == kCSRStorage &&
              !(param.transpose_a || param.transpose_b)) {
-    NDArray ret = outputs[0];
-    DotDnsCsrCsrImpl(ctx, xpu(), inputs[0].data(), inputs[1], req[0], &ret);
+    NDArray ret = outputs[DotOut::out];
+    DotDnsCsrCsrImpl(
+        ctx, xpu(), inputs[DotIn::lhs].data(), inputs[DotIn::rhs], req[DotOut::out], &ret);
   } else if (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
              out_stype == kDefaultStorage && !(param.transpose_a)) {
-    NDArray ret = outputs[0];
-    DotDnsCsrDnsImpl(ctx, xpu(), inputs[0].data(), inputs[1], req[0], &ret, param.transpose_b);
+    NDArray ret = outputs[DotOut::out];
+    DotDnsCsrDnsImpl(ctx,
+                     xpu(),
+                     inputs[DotIn::lhs].data(),
+                     inputs[DotIn::rhs],
+                     req[DotOut::out],
+                     &ret,
+                     param.transpose_b);
   } else {
     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
   }
@@ -1471,47 +1537,72 @@
   using namespace mshadow;
   using namespace mshadow::expr;
   mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
-  if (req[0] == kNullOp)
+  if (req[DotOut::out] == kNullOp)
     return;
   const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
-  CHECK_EQ(outputs[0].type_flag_, inputs[0].type_flag_)
+  CHECK_EQ(outputs[DotOut::out].type_flag_, inputs[DotIn::lhs].type_flag_)
       << "Binary function only support input/output with the same type";
-  CHECK_EQ(outputs[0].type_flag_, inputs[1].type_flag_)
+  CHECK_EQ(outputs[DotOut::out].type_flag_, inputs[DotIn::rhs].type_flag_)
       << "Binary function only support input/output with the same type";
-  CHECK(outputs[0].type_flag_ == kFloat32 || outputs[0].type_flag_ == kFloat64 ||
-        (outputs[0].type_flag_ == kFloat16 && ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask))
+  CHECK(outputs[DotOut::out].type_flag_ == kFloat32 ||
+        outputs[DotOut::out].type_flag_ == kFloat64 ||
+        (outputs[DotOut::out].type_flag_ == kFloat16 &&
+         ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask))
       << "dot only supports float32/float64 for CPU, and float16/float32/float64 for GPU";
-  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-    int ndim = outputs[0].ndim();
-    if (outputs[0].shape_.Size() == 0 || inputs[0].shape_.Size() == 0 ||
-        inputs[1].shape_.Size() == 0) {
-      if (outputs[0].shape_.Size() != 0 && req[0] != kAddTo) {
+  MSHADOW_REAL_TYPE_SWITCH(outputs[DotOut::out].type_flag_, DType, {
+    int ndim = outputs[DotOut::out].ndim();
+    if (outputs[DotOut::out].shape_.Size() == 0 || inputs[DotIn::lhs].shape_.Size() == 0 ||
+        inputs[DotIn::rhs].shape_.Size() == 0) {
+      if (outputs[DotOut::out].shape_.Size() != 0 && req[DotOut::out] != kAddTo) {
         mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
-            s, outputs[0].shape_.Size(), outputs[0].dptr<DType>());
+            s, outputs[DotOut::out].shape_.Size(), outputs[DotOut::out].dptr<DType>());
       }
       return;
     }
-    size_t batch_size                  = outputs[0].shape_.ProdShape(0, ndim - 2);
-    mshadow::Tensor<xpu, 3, DType> out = outputs[0].get_with_shape<xpu, 3, DType>(
-        Shape3(batch_size, outputs[0].shape_[ndim - 2], outputs[0].shape_[ndim - 1]), s);
-    mshadow::Tensor<xpu, 3, DType> mlhs = inputs[0].get_with_shape<xpu, 3, DType>(
-        Shape3(batch_size, inputs[0].shape_[ndim - 2], inputs[0].shape_[ndim - 1]), s);
-    mshadow::Tensor<xpu, 3, DType> mrhs = inputs[1].get_with_shape<xpu, 3, DType>(
-        Shape3(batch_size, inputs[1].shape_[ndim - 2], inputs[1].shape_[ndim - 1]), s);
+    size_t batch_size                  = outputs[DotOut::out].shape_.ProdShape(0, ndim - 2);
+    mshadow::Tensor<xpu, 3, DType> out = outputs[DotOut::out].get_with_shape<xpu, 3, DType>(
+        Shape3(batch_size,
+               outputs[DotOut::out].shape_[ndim - 2],
+               outputs[DotOut::out].shape_[ndim - 1]),
+        s);
+    mshadow::Tensor<xpu, 3, DType> mlhs = inputs[DotIn::lhs].get_with_shape<xpu, 3, DType>(
+        Shape3(
+            batch_size, inputs[DotIn::lhs].shape_[ndim - 2], inputs[DotIn::lhs].shape_[ndim - 1]),
+        s);
+    mshadow::Tensor<xpu, 3, DType> mrhs = inputs[DotIn::rhs].get_with_shape<xpu, 3, DType>(
+        Shape3(
+            batch_size, inputs[DotIn::rhs].shape_[ndim - 2], inputs[DotIn::rhs].shape_[ndim - 1]),
+        s);
     mshadow::Tensor<xpu, 1, DType*> workspace =
         ctx.requested[0].get_space_typed<xpu, 1, DType*>(mshadow::Shape1(3 * out.size(0)), s);
     if (param.transpose_a && param.transpose_b) {
-      mshadow::BatchGEMM<true, true>(
-          out, mlhs, mrhs, (DType)1.0f, (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, workspace);
+      mshadow::BatchGEMM<true, true>(out,
+                                     mlhs,
+                                     mrhs,
+                                     (DType)1.0f,
+                                     (kAddTo == req[DotOut::out]) ? (DType)1.0f : (DType)0.0f,
+                                     workspace);
     } else if (!param.transpose_a && param.transpose_b) {
-      mshadow::BatchGEMM<false, true>(
-          out, mlhs, mrhs, (DType)1.0f, (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, workspace);
+      mshadow::BatchGEMM<false, true>(out,
+                                      mlhs,
+                                      mrhs,
+                                      (DType)1.0f,
+                                      (kAddTo == req[DotOut::out]) ? (DType)1.0f : (DType)0.0f,
+                                      workspace);
     } else if (param.transpose_a && !param.transpose_b) {
-      mshadow::BatchGEMM<true, false>(
-          out, mlhs, mrhs, (DType)1.0f, (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, workspace);
+      mshadow::BatchGEMM<true, false>(out,
+                                      mlhs,
+                                      mrhs,
+                                      (DType)1.0f,
+                                      (kAddTo == req[DotOut::out]) ? (DType)1.0f : (DType)0.0f,
+                                      workspace);
     } else {
-      mshadow::BatchGEMM<false, false>(
-          out, mlhs, mrhs, (DType)1.0f, (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, workspace);
+      mshadow::BatchGEMM<false, false>(out,
+                                       mlhs,
+                                       mrhs,
+                                       (DType)1.0f,
+                                       (kAddTo == req[DotOut::out]) ? (DType)1.0f : (DType)0.0f,
+                                       workspace);
     }
   });
 }
@@ -1565,4 +1656,16 @@
 
 }  // namespace op
 }  // namespace mxnet
+namespace std {
+template <>
+struct hash<mxnet::op::DotParam> {
+  size_t operator()(const mxnet::op::DotParam& val) {
+    size_t ret = 0;
+    ret        = dmlc::HashCombine(ret, val.transpose_a);
+    ret        = dmlc::HashCombine(ret, val.transpose_b);
+    ret        = dmlc::HashCombine(ret, val.forward_stype);
+    return ret;
+  }
+};
+}  // namespace std
 #endif  // MXNET_OPERATOR_TENSOR_DOT_INL_H_
diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc
index defed11..0d62cac 100644
--- a/src/operator/tensor/dot.cc
+++ b/src/operator/tensor/dot.cc
@@ -24,9 +24,8 @@
 
 #include "./dot-inl.h"
 #if MXNET_USE_ONEDNN == 1
-#include "./../nn/dnnl/dnnl_base-inl.h"
-#include "./../nn/dnnl/dnnl_ops-inl.h"
-#include "./../nn/dnnl/dnnl_batch_dot-inl.h"
+#include "operator/nn/dnnl/dnnl_batch_dot-inl.h"
+#include "operator/nn/dnnl/dnnl_dot-inl.h"
 #endif  // MXNET_USE_ONEDNN
 
 namespace mxnet {
@@ -97,6 +96,9 @@
     .set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
     .set_attr<FCompute>("FCompute<cpu>", DotForward_<cpu>)
     .set_attr<FComputeEx>("FComputeEx<cpu>", DotForwardEx<cpu>)
+#if MXNET_USE_ONEDNN == 1
+    .set_attr<bool>("TIsMKLDNN", true)
+#endif
     .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_dot"})
     .add_argument("lhs", "NDArray-or-Symbol", "The first input")
     .add_argument("rhs", "NDArray-or-Symbol", "The second input")
@@ -117,12 +119,26 @@
     .add_arguments(DotParam::__FIELDS__());
 
 #if MXNET_USE_ONEDNN == 1
+void DotForwardExDNNL(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& outputs) {
+  if (SupportDNNLDot(inputs, outputs[DotOut::out])) {
+    DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+    DNNLRun(DNNLDotForward<false>, attrs, ctx, inputs, req, outputs);
+    DNNL_OPCHECK_RUN(DotForward_<cpu>, attrs, ctx, inputs, req, outputs);
+  } else {
+    FallBackCompute(DotForward_<cpu>, attrs, ctx, inputs, req, outputs);
+  }
+}
+
 static void BatchDotComputeExCPU(const nnvm::NodeAttrs& attrs,
                                  const OpContext& ctx,
                                  const std::vector<NDArray>& inputs,
                                  const std::vector<OpReqType>& req,
                                  const std::vector<NDArray>& outputs) {
-  if (SupportDNNLBatchDot(inputs, outputs[0])) {
+  if (SupportDNNLBatchDot(inputs, outputs[DotOut::out])) {
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     DNNLRun(DNNLBatchDotForward<false>, attrs, ctx, inputs, req, outputs);
     DNNL_OPCHECK_RUN(BatchDotForward_<cpu>, attrs, ctx, inputs, req, outputs);