[operator] Integrate oneDNN matmul primitive to mxnet dot operator (#20911)
* [operator] Integrate oneDNN matmul primitive to mxnet dot operator
* Fix numpy dot test
* Fix reviews
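
The core of the change is the new `DNNLDotFwd` class in src/operator/nn/dnnl/dnnl_dot.cc: both `dot` and `np.dot` are lowered to a single oneDNN `matmul` primitive by flattening lhs into an (M, K) matrix and rhs into a (K, N) matrix. For orientation, below is a minimal standalone sketch of that lowering against the oneDNN 2.x API (an illustration with made-up shapes, not code from this patch):

```cpp
// Minimal sketch: a (6, 4) x (4, 10) dot executed as one dnnl::matmul primitive (oneDNN 2.x API).
// Error handling omitted.
#include <vector>
#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  const int M = 6, K = 4, N = 10;
  std::vector<float> lhs(M * K, 1.f), rhs(K * N, 1.f), out(M * N, 0.f);

  using tag = dnnl::memory::format_tag;
  using dt  = dnnl::memory::data_type;
  // Plain row-major descriptors; a transposed operand would use tag::ba instead of tag::ab.
  dnnl::memory::desc lhs_md({M, K}, dt::f32, tag::ab);
  dnnl::memory::desc rhs_md({K, N}, dt::f32, tag::ab);
  dnnl::memory::desc out_md({M, N}, dt::f32, tag::ab);

  dnnl::matmul::desc fwd_desc(lhs_md, rhs_md, out_md);
  dnnl::matmul::primitive_desc fwd_pd(fwd_desc, eng);
  dnnl::matmul fwd(fwd_pd);

  dnnl::memory lhs_mem(lhs_md, eng, lhs.data());
  dnnl::memory rhs_mem(rhs_md, eng, rhs.data());
  dnnl::memory out_mem(out_md, eng, out.data());

  fwd.execute(strm,
              {{DNNL_ARG_SRC, lhs_mem}, {DNNL_ARG_WEIGHTS, rhs_mem}, {DNNL_ARG_DST, out_mem}});
  strm.wait();  // out now holds the (6, 10) result; with all-ones inputs every element equals 4.
  return 0;
}
```

In the actual operator the destination descriptor uses `format_tag::any` so oneDNN can pick the output layout, and `DNNLDotFwd::GetCached` caches one primitive per combination of input/output shapes, dtypes, transpose flags and NumPy/NDArray semantics.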
diff --git a/src/operator/nn/dnnl/dnnl_base-inl.h b/src/operator/nn/dnnl/dnnl_base-inl.h
index 7c9a916..97aacc1 100644
--- a/src/operator/nn/dnnl/dnnl_base-inl.h
+++ b/src/operator/nn/dnnl/dnnl_base-inl.h
@@ -198,6 +198,7 @@
const NDArray& output);
bool SupportDNNLSoftmaxOutput(const SoftmaxOutputParam& param);
bool SupportDNNLTranspose(const NDArray& data);
+bool SupportDNNLDot(const std::vector<NDArray>& inputs, const NDArray& output);
bool SupportDNNLBatchDot(const std::vector<NDArray>& inputs, const NDArray& output);
bool SupportDNNLLayerNorm(const LayerNormParam& param, const std::vector<NDArray>& inputs);
bool SupportDNNLReshape(const NDArray& input, const NDArray& output);
@@ -540,6 +541,7 @@
dnnl_format_tag_t GetDefaultFormat(const dnnl::memory::desc& md);
dnnl_format_tag_t GetDefaultFormat(int num_dims);
+dnnl_format_tag_t GetPermutedFormat(int num_dims);
dnnl::memory::desc GetDesc(const dnnl::memory::desc& md, const dnnl_format_tag_t& format);
inline bool same_shape(const mxnet::TShape& shape, const dnnl_dims_t dims, int ndims) {
diff --git a/src/operator/nn/dnnl/dnnl_base.cc b/src/operator/nn/dnnl/dnnl_base.cc
index 05fabd5..ec50da7 100644
--- a/src/operator/nn/dnnl/dnnl_base.cc
+++ b/src/operator/nn/dnnl/dnnl_base.cc
@@ -347,6 +347,38 @@
}
}
+dnnl_format_tag_t GetPermutedFormat(int num_dims) {
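+  // Plain row-major format tags with the two innermost dimensions transposed (e.g. abc -> acb,
+  // abcd -> abdc); used by the NumPy dot path to express the Ax...xKxN -> Ax...xNxK permutation.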
+ switch (num_dims) {
+ case 1:
+ return dnnl_a;
+ case 2:
+ return dnnl_ba;
+ case 3:
+ return dnnl_acb;
+ case 4:
+ return dnnl_abdc;
+ case 5:
+ return dnnl_abced;
+ case 6:
+ return dnnl_abcdfe;
+ case 7:
+ return dnnl_abcdegf;
+ case 8:
+ return dnnl_abcdefhg;
+ case 9:
+ return dnnl_abcdefgih;
+ case 10:
+ return dnnl_abcdefghji;
+ case 11:
+ return dnnl_abcdefghikj;
+ case 12:
+ return dnnl_abcdefghijlk;
+ default:
+ LOG(FATAL) << "Not implemented dimension (" << num_dims << ") for oneDNN";
+ return dnnl_format_tag_undef;
+ }
+}
+
dnnl_format_tag_t GetDefaultFormat(const dnnl::memory::desc& desc) {
return GetDefaultFormat(desc.data.ndims);
}
diff --git a/src/operator/nn/dnnl/dnnl_batch_dot-inl.h b/src/operator/nn/dnnl/dnnl_batch_dot-inl.h
index 4117b17..b48afd1 100644
--- a/src/operator/nn/dnnl/dnnl_batch_dot-inl.h
+++ b/src/operator/nn/dnnl/dnnl_batch_dot-inl.h
@@ -38,9 +38,6 @@
namespace mxnet {
namespace op {
-enum DotIn { lhs = 0, rhs, lhs_min, lhs_max, rhs_min, rhs_max };
-enum DotOut { out = 0, out_min, out_max };
-
struct DNNLDotParam : public dmlc::Parameter<DNNLDotParam> {
bool transpose_a;
bool transpose_b;
diff --git a/src/operator/nn/dnnl/dnnl_dot-inl.h b/src/operator/nn/dnnl/dnnl_dot-inl.h
new file mode 100644
index 0000000..b375872
--- /dev/null
+++ b/src/operator/nn/dnnl/dnnl_dot-inl.h
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_dot-inl.h
+ */
+
+#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_DOT_INL_H_
+#define MXNET_OPERATOR_NN_DNNL_DNNL_DOT_INL_H_
+
+#if MXNET_USE_ONEDNN == 1
+
+#include <memory>
+#include <vector>
+
+#include "dnnl_base-inl.h"
+#include "operator/tensor/dot-inl.h"
+
+namespace mxnet {
+namespace op {
+
+using dot_fwd_t = dnnl::matmul;
+using dot_fwd_pd_t = dnnl::matmul::primitive_desc;
+
+typedef ParamOpSign<DotParam> DotSignature;
+
+class DNNLDotFwd {
+ public:
+ static DNNLDotFwd& GetCached(const DotParam& param,
+ const std::vector<NDArray>& inputs,
+ const std::vector<NDArray>& outputs,
+ const bool isNumpy);
+
+ DNNLDotFwd(const DotParam& param,
+ const std::vector<NDArray>& inputs,
+ const std::vector<NDArray>& outputs,
+ const bool isNumpy);
+
+ void Execute(const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs,
+ const bool isNumpy);
+
+ private:
+ std::shared_ptr<dot_fwd_t> fwd;
+ std::shared_ptr<dot_fwd_pd_t> fwd_pd;
+};
+
+template <bool isNumpy>
+void DNNLDotForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ DotParam param;
+ if (isNumpy) {
+    // The NumPy version of the dot operator does not support transpose flags.
+ param = DotParam();
+ param.transpose_a = false;
+ param.transpose_b = false;
+ } else {
+ param = nnvm::get<DotParam>(attrs.parsed);
+ }
+ DNNLDotFwd& fwd = DNNLDotFwd::GetCached(param, inputs, outputs, isNumpy);
+ fwd.Execute(inputs, req, outputs, isNumpy);
+}
+} // namespace op
+} // namespace mxnet
+#endif // MXNET_USE_ONEDNN == 1
+#endif // MXNET_OPERATOR_NN_DNNL_DNNL_DOT_INL_H_
diff --git a/src/operator/nn/dnnl/dnnl_dot.cc b/src/operator/nn/dnnl/dnnl_dot.cc
new file mode 100644
index 0000000..3978646
--- /dev/null
+++ b/src/operator/nn/dnnl/dnnl_dot.cc
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_dot.cc
+ */
+
+#if MXNET_USE_ONEDNN == 1
+
+#include <memory>
+#include <unordered_map>
+
+#include "dnnl_dot-inl.h"
+
+namespace mxnet {
+namespace op {
+
+bool SupportDNNLDot(const std::vector<NDArray>& inputs, const NDArray& output) {
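+  // When MXNet is built against MKL BLAS, dot already runs on MKL GEMM, so the oneDNN path is
+  // disabled.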
+#if MXNET_USE_BLAS_MKL == 1
+ return false;
+#endif
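+  // Use oneDNN matmul only for fp32/bf16 inputs with more than one element and a non-empty output
+  // of at most 12 dimensions (the highest rank handled by GetPermutedFormat and by oneDNN).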
+ return inputs[DotIn::lhs].shape().Size() > 1 && inputs[DotIn::rhs].shape().Size() > 1 &&
+ inputs[DotIn::lhs].shape().ndim() > 0 && inputs[DotIn::rhs].shape().ndim() > 0 &&
+ output.shape().Size() != 0 && output.shape().ndim() > 0 && output.shape().ndim() <= 12 &&
+ (inputs[DotIn::lhs].dtype() == mshadow::kFloat32 ||
+ inputs[DotIn::lhs].dtype() == mshadow::kBfloat16);
+}
+
+DNNLDotFwd& DNNLDotFwd::GetCached(const DotParam& param,
+ const std::vector<NDArray>& inputs,
+ const std::vector<NDArray>& outputs,
+ const bool isNumpy) {
+ using dot_fwd_map = std::unordered_map<DotSignature, DNNLDotFwd, OpHash>;
+#if DMLC_CXX11_THREAD_LOCAL
+ static thread_local dot_fwd_map fwds;
+#else
+ static MX_THREAD_LOCAL dot_fwd_map fwds;
+#endif
+
+ DotSignature key(param);
+ key.AddSign(inputs[DotIn::lhs]);
+ key.AddSign(inputs[DotIn::rhs]);
+ key.AddSign(outputs[DotOut::out]);
+ key.AddSign(static_cast<int>(isNumpy));
+
+ auto it = fwds.find(key);
+ if (it == fwds.end()) {
+ const DNNLDotFwd fwd(param, inputs, outputs, isNumpy);
+ it = AddToCache(&fwds, key, fwd);
+ }
+ return it->second;
+}
+
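+// Builds a 2-D memory descriptor for a flattened operand; the ba tag makes oneDNN read the buffer
+// as the transpose of the given (firstDim, secondDim) shape.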
+auto GetMemoryDesc(const NDArray& tensor, int firstDim, int secondDim, const bool transpose) {
+ return dnnl::memory::desc(
+ dnnl::memory::dims{firstDim, secondDim},
+ get_dnnl_type(tensor.dtype()),
+ transpose ? dnnl::memory::format_tag::ba : dnnl::memory::format_tag::ab);
+}
+
+DNNLDotFwd::DNNLDotFwd(const DotParam& param,
+ const std::vector<NDArray>& inputs,
+ const std::vector<NDArray>& outputs,
+ const bool isNumpy) {
+ auto shapeLhs = inputs[DotIn::lhs].shape(), shapeRhs = inputs[DotIn::rhs].shape();
+ auto ndimLhs = shapeLhs.ndim(), ndimRhs = shapeRhs.ndim();
+ dnnl::memory::desc lhs_md, rhs_md, out_md;
+  // NumPy expects an rhs tensor with more than 2 dimensions in the Ax...xKxN layout, which differs
+  // from NDArray's KxAx...xN layout. For NumPy the rhs memory descriptor still has the Kx(A*...*N)
+  // shape, just as for NDArray, but for it to match the actual data an additional reorder is
+  // needed, permuting the last two axes: Ax...xKxN -> Ax...xNxK. For the reordered data to match
+  // the Kx(A*...*N) shape, the format_tag has to be set to ba. The reorder described above is
+  // implemented in the Execute function.
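+  // For example, a NumPy dot of an lhs of shape (6, 4) with an rhs of shape (2, 4, 5) runs as a
+  // (6, 4) x (4, 10) matmul: Execute first reorders the rhs buffer to the (2, 5, 4) layout so that
+  // the ba-tagged (4, 10) descriptor matches the permuted data.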
+ const bool differentNumpy = isNumpy && ndimRhs > 2;
+ const int smallDimLhs = param.transpose_a ? shapeLhs[0] : shapeLhs[ndimLhs - 1];
+ const int bigDimLhs = shapeLhs.Size() / smallDimLhs;
+ const int smallDimRhs = param.transpose_b ?
+ shapeRhs[ndimRhs - 1] :
+ (differentNumpy ? shapeRhs[ndimRhs - 2] : shapeRhs[0]);
+ const int bigDimRhs = shapeRhs.Size() / smallDimRhs;
+
+ lhs_md = GetMemoryDesc(inputs[DotIn::lhs], bigDimLhs, smallDimLhs, param.transpose_a);
+ rhs_md = GetMemoryDesc(
+ inputs[DotIn::rhs], smallDimRhs, bigDimRhs, param.transpose_b || differentNumpy);
+ out_md = dnnl::memory::desc({bigDimLhs, bigDimRhs},
+ get_dnnl_type(outputs[DotOut::out].dtype()),
+ dnnl::memory::format_tag::any);
+ dnnl::matmul::desc fwd_desc(lhs_md, rhs_md, out_md);
+ fwd_pd = std::make_shared<dot_fwd_pd_t>(fwd_desc, mxnet::CpuEngine::Get()->get_engine());
+ fwd = std::make_shared<dot_fwd_t>(*fwd_pd);
+}
+
+void DNNLDotFwd::Execute(const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs,
+ const bool isNumpy) {
+ auto engine = mxnet::CpuEngine::Get()->get_engine();
+ auto lhs = dnnl::memory(
+ fwd_pd->src_desc(), engine, reinterpret_cast<void*>(inputs[DotIn::lhs].data().dptr_));
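+  // Constructed without a user handle, so oneDNN allocates a buffer that can receive the reordered
+  // NumPy rhs; in the non-reordered case the handle is simply replaced with the MXNet data pointer.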
+ auto rhs = dnnl::memory(fwd_pd->weights_desc(), engine);
+ auto ndimRhs = inputs[DotIn::rhs].shape().ndim();
+ if (isNumpy && ndimRhs > 2) {
+    // The need for this reorder is explained in the DNNLDotFwd constructor.
+ auto tmp_rhs = inputs[DotIn::rhs].GetDNNLData();
+ dnnl::memory::desc rhs_md(
+ dnnl::memory::dims(inputs[DotIn::rhs].shape().begin(), inputs[DotIn::rhs].shape().end()),
+ get_dnnl_type(inputs[DotIn::rhs].dtype()),
+ static_cast<dnnl::memory::format_tag>(GetPermutedFormat(ndimRhs)));
+ dnnl::memory tmp_rhs_dst(rhs_md, engine, rhs.get_data_handle());
+ const auto rhs_reorder_pd = dnnl::reorder::primitive_desc(*tmp_rhs, tmp_rhs_dst);
+ DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(rhs_reorder_pd),
+ {{DNNL_ARG_FROM, *tmp_rhs}, {DNNL_ARG_TO, tmp_rhs_dst}});
+ } else {
+ rhs.set_data_handle(reinterpret_cast<void*>(inputs[DotIn::rhs].data().dptr_));
+ }
+ dnnl_output_t out_mem = CreateDNNLMem(
+ outputs[DotOut::out], fwd_pd->dst_desc(), req[DotOut::out], &inputs[DotIn::lhs]);
+
+ dnnl_args_map_t args = {
+ {DNNL_ARG_SRC, lhs},
+ {DNNL_ARG_WEIGHTS, rhs},
+ {DNNL_ARG_DST, *out_mem.second},
+ };
+
+ DNNLStream::Get()->RegisterPrimArgs(*fwd, args);
+ CommitOutput(outputs[DotOut::out], out_mem);
+ DNNLStream::Get()->Submit();
+}
+
+} // namespace op
+} // namespace mxnet
+#endif // MXNET_USE_ONEDNN == 1
diff --git a/src/operator/numpy/np_dot_forward.cc b/src/operator/numpy/np_dot_forward.cc
index 1c2da2d..8704860 100644
--- a/src/operator/numpy/np_dot_forward.cc
+++ b/src/operator/numpy/np_dot_forward.cc
@@ -22,7 +22,10 @@
* \brief CPU Implementation of numpy-compatible dot
*/
-#include "./np_dot-inl.h"
+#include "np_dot-inl.h"
+#if MXNET_USE_ONEDNN == 1
+#include "operator/nn/dnnl/dnnl_dot-inl.h"
+#endif
namespace mxnet {
namespace op {
@@ -101,6 +104,32 @@
return shape_is_known(*in_attrs) && shape_is_known(*out_attrs);
}
+#if MXNET_USE_ONEDNN == 1
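+// Runs the oneDNN dot kernel when the inputs are supported, otherwise falls back to the default
+// NumpyDotForward<cpu> implementation.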
+static void NumpyDotComputeExCPU(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ if (SupportDNNLDot(inputs, outputs[0])) {
+ DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+ DNNLRun(DNNLDotForward<true>, attrs, ctx, inputs, req, outputs);
+ DNNL_OPCHECK_RUN(NumpyDotForward<cpu>, attrs, ctx, inputs, req, outputs);
+ } else {
+ FallBackCompute(NumpyDotForward<cpu>, attrs, ctx, inputs, req, outputs);
+ }
+}
+
+inline static bool NumpyDotStorageType(const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int>* in_attrs,
+ std::vector<int>* out_attrs) {
+ CHECK_EQ(in_attrs->size(), 2U);
+ CHECK_EQ(out_attrs->size(), 1U);
+ return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
+}
+#endif
+
NNVM_REGISTER_OP(_npi_dot)
.describe(R"doc(Dot product of two arrays. Specifically,
@@ -134,6 +163,11 @@
})
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<FCompute>("FCompute<cpu>", NumpyDotForward<cpu>)
+#if MXNET_USE_ONEDNN == 1
+ .set_attr<bool>("TIsDNNL", true)
+ .set_attr<FComputeEx>("FComputeEx<cpu>", NumpyDotComputeExCPU)
+ .set_attr<FInferStorageType>("FInferStorageType", NumpyDotStorageType)
+#endif
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_dot"})
.add_argument("a", "NDArray-or-Symbol", "First input")
.add_argument("b", "NDArray-or-Symbol", "Second input");
diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h
index 7cd9fa9..dac690f 100644
--- a/src/operator/tensor/dot-inl.h
+++ b/src/operator/tensor/dot-inl.h
@@ -44,6 +44,9 @@
namespace mxnet {
namespace op {
+enum DotIn { lhs = 0, rhs, lhs_min, lhs_max, rhs_min, rhs_max };
+enum DotOut { out = 0, out_min, out_max };
+
struct DotParam : public dmlc::Parameter<DotParam> {
bool transpose_a;
bool transpose_b;
@@ -109,50 +112,57 @@
using namespace mshadow::expr;
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
- CHECK_EQ(outputs[0].type_flag_, inputs[0].type_flag_)
+ CHECK_EQ(outputs[DotOut::out].type_flag_, inputs[DotIn::lhs].type_flag_)
<< "Binary function only support input/output with the same type";
- CHECK_EQ(outputs[0].type_flag_, inputs[1].type_flag_)
+ CHECK_EQ(outputs[DotOut::out].type_flag_, inputs[DotIn::rhs].type_flag_)
<< "Binary function only support input/output with the same type";
- CHECK(outputs[0].type_flag_ == kFloat32 || outputs[0].type_flag_ == kFloat64 ||
- (outputs[0].type_flag_ == kFloat16 && ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask))
+ CHECK(outputs[DotOut::out].type_flag_ == kFloat32 ||
+ outputs[DotOut::out].type_flag_ == kFloat64 ||
+ (outputs[DotOut::out].type_flag_ == kFloat16 &&
+ ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask))
<< "dot only supports float32/float64 for CPU, and float16/float32/float64 for GPU";
- MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+ MSHADOW_REAL_TYPE_SWITCH(outputs[DotOut::out].type_flag_, DType, {
// VectorDot() with fp16 is not supported in mshadow. Dispatch to dot() instead.
- if (inputs[0].ndim() == 1 && inputs[1].ndim() == 1 && inputs[0].type_flag_ != kFloat16) {
- CHECK_NE(req[0], kAddTo) << "AddTo not yet supported";
- Tensor<xpu, 1, DType> out = outputs[0].get<xpu, 1, DType>(s);
- VectorDot(out, inputs[0].get<xpu, 1, DType>(s), inputs[1].get<xpu, 1, DType>(s));
+ if (inputs[DotIn::lhs].ndim() == 1 && inputs[DotIn::rhs].ndim() == 1 &&
+ inputs[DotIn::lhs].type_flag_ != kFloat16) {
+ CHECK_NE(req[DotOut::out], kAddTo) << "AddTo not yet supported";
+ Tensor<xpu, 1, DType> out = outputs[DotOut::out].get<xpu, 1, DType>(s);
+ VectorDot(
+ out, inputs[DotIn::lhs].get<xpu, 1, DType>(s), inputs[DotIn::rhs].get<xpu, 1, DType>(s));
} else {
index_t ma, na, mb, nb, m, n;
if (param.transpose_a) {
- ma = inputs[0].size(0);
- na = inputs[0].Size() / ma;
+ ma = inputs[DotIn::lhs].size(0);
+ na = inputs[DotIn::lhs].Size() / ma;
m = na;
} else {
- na = inputs[0].size(inputs[0].ndim() - 1);
- ma = inputs[0].Size() / na;
+ na = inputs[DotIn::lhs].size(inputs[DotIn::lhs].ndim() - 1);
+ ma = inputs[DotIn::lhs].Size() / na;
m = ma;
}
if (param.transpose_b) {
- nb = inputs[1].size(inputs[1].ndim() - 1);
- mb = inputs[1].Size() / nb;
+ nb = inputs[DotIn::rhs].size(inputs[DotIn::rhs].ndim() - 1);
+ mb = inputs[DotIn::rhs].Size() / nb;
n = mb;
} else {
- mb = inputs[1].size(0);
- nb = inputs[1].Size() / mb;
+ mb = inputs[DotIn::rhs].size(0);
+ nb = inputs[DotIn::rhs].Size() / mb;
n = nb;
}
- Tensor<xpu, 2, DType> input0 = inputs[0].get_with_shape<xpu, 2, DType>(Shape2(ma, na), s);
- Tensor<xpu, 2, DType> input1 = inputs[1].get_with_shape<xpu, 2, DType>(Shape2(mb, nb), s);
- Tensor<xpu, 2, DType> out = outputs[0].get_with_shape<xpu, 2, DType>(Shape2(m, n), s);
+ Tensor<xpu, 2, DType> input0 =
+ inputs[DotIn::lhs].get_with_shape<xpu, 2, DType>(Shape2(ma, na), s);
+ Tensor<xpu, 2, DType> input1 =
+ inputs[DotIn::rhs].get_with_shape<xpu, 2, DType>(Shape2(mb, nb), s);
+ Tensor<xpu, 2, DType> out =
+ outputs[DotOut::out].get_with_shape<xpu, 2, DType>(Shape2(m, n), s);
if (param.transpose_a && param.transpose_b) {
- ASSIGN_DISPATCH(out, req[0], dot(input0.T(), input1.T()));
+ ASSIGN_DISPATCH(out, req[DotOut::out], dot(input0.T(), input1.T()));
} else if (!param.transpose_a && param.transpose_b) {
- ASSIGN_DISPATCH(out, req[0], dot(input0, input1.T()));
+ ASSIGN_DISPATCH(out, req[DotOut::out], dot(input0, input1.T()));
} else if (param.transpose_a && !param.transpose_b) {
- ASSIGN_DISPATCH(out, req[0], dot(input0.T(), input1));
+ ASSIGN_DISPATCH(out, req[DotOut::out], dot(input0.T(), input1));
} else {
- ASSIGN_DISPATCH(out, req[0], dot(input0, input1));
+ ASSIGN_DISPATCH(out, req[DotOut::out], dot(input0, input1));
}
}
});
@@ -257,8 +267,14 @@
// dns, dns -> dns
target_stype = hint_has_value ? target_stype : kDefaultStorage;
if (target_stype == kDefaultStorage) {
- dispatched =
- storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
+ dispatched = storage_type_assign(&out_stype,
+ kDefaultStorage,
+ dispatch_mode,
+#if MXNET_USE_ONEDNN == 1
+ DispatchMode::kFComputeEx);
+#else
+ DispatchMode::kFCompute);
+#endif
}
}
if (!dispatched && lhs_stype == kCSRStorage && only_lhs_transpose && rhs_rsp_or_dns) {
@@ -1369,6 +1385,14 @@
return shape_is_known((*out_attrs)[0]);
}
+#if MXNET_USE_ONEDNN == 1
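+// Implemented in dot.cc.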
+void DotForwardExDNNL(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+#endif
+
template <typename xpu>
void DotForwardEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -1378,36 +1402,78 @@
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
+#if MXNET_USE_ONEDNN == 1
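+  // All-dense dot goes through the oneDNN implementation on CPU (DotForward_ otherwise); the
+  // sparse storage combinations are handled below.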
+ if (common::ContainsOnlyStorage(inputs, kDefaultStorage) &&
+ common::ContainsOnlyStorage(outputs, kDefaultStorage)) {
+ if (std::is_same<xpu, cpu>::value) {
+ DotForwardExDNNL(attrs, ctx, inputs, req, outputs);
+ } else {
+ FallBackCompute(DotForward_<gpu>, attrs, ctx, inputs, req, outputs);
+ }
+ return;
+ }
+#endif
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
- CHECK_EQ(inputs[0].shape().ndim(), 2) << "sparse dot only supports 2 dimensional lhs";
- CHECK_EQ(inputs[1].shape().ndim(), 2) << "sparse dot only supports 2 dimensional rhs";
- auto lhs_stype = inputs[0].storage_type();
- auto rhs_stype = inputs[1].storage_type();
- auto out_stype = outputs[0].storage_type();
+ CHECK_EQ(inputs[DotIn::lhs].shape().ndim(), 2) << "sparse dot only supports 2 dimensional lhs";
+ CHECK_EQ(inputs[DotIn::rhs].shape().ndim(), 2) << "sparse dot only supports 2 dimensional rhs";
+ auto lhs_stype = inputs[DotIn::lhs].storage_type();
+ auto rhs_stype = inputs[DotIn::rhs].storage_type();
+ auto out_stype = outputs[DotOut::out].storage_type();
if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kDefaultStorage &&
!param.transpose_b) {
- TBlob ret = outputs[0].data();
- DotCsrDnsDnsImpl(ctx, xpu(), inputs[0], inputs[1].data(), req[0], param.transpose_a, &ret);
+ TBlob ret = outputs[DotOut::out].data();
+ DotCsrDnsDnsImpl(ctx,
+ xpu(),
+ inputs[DotIn::lhs],
+ inputs[DotIn::rhs].data(),
+ req[DotOut::out],
+ param.transpose_a,
+ &ret);
} else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage &&
out_stype == kDefaultStorage && !param.transpose_b) {
- TBlob ret = outputs[0].data();
- DotCsrRspDnsImpl(ctx, xpu(), inputs[0], inputs[1], req[0], param.transpose_a, &ret);
+ TBlob ret = outputs[DotOut::out].data();
+ DotCsrRspDnsImpl(ctx,
+ xpu(),
+ inputs[DotIn::lhs],
+ inputs[DotIn::rhs],
+ req[DotOut::out],
+ param.transpose_a,
+ &ret);
} else if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage &&
out_stype == kRowSparseStorage && !param.transpose_b) {
- NDArray out = outputs[0];
- DotCsrDnsRspImpl(ctx, xpu(), inputs[0], inputs[1].data(), req[0], param.transpose_a, &out);
+ NDArray out = outputs[DotOut::out];
+ DotCsrDnsRspImpl(ctx,
+ xpu(),
+ inputs[DotIn::lhs],
+ inputs[DotIn::rhs].data(),
+ req[DotOut::out],
+ param.transpose_a,
+ &out);
} else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage &&
out_stype == kRowSparseStorage && !param.transpose_b) {
- NDArray ret = outputs[0];
- DotCsrRspRspImpl(ctx, xpu(), inputs[0], inputs[1], req[0], param.transpose_a, &ret);
+ NDArray ret = outputs[DotOut::out];
+ DotCsrRspRspImpl(ctx,
+ xpu(),
+ inputs[DotIn::lhs],
+ inputs[DotIn::rhs],
+ req[DotOut::out],
+ param.transpose_a,
+ &ret);
} else if (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage && out_stype == kCSRStorage &&
!(param.transpose_a || param.transpose_b)) {
- NDArray ret = outputs[0];
- DotDnsCsrCsrImpl(ctx, xpu(), inputs[0].data(), inputs[1], req[0], &ret);
+ NDArray ret = outputs[DotOut::out];
+ DotDnsCsrCsrImpl(
+ ctx, xpu(), inputs[DotIn::lhs].data(), inputs[DotIn::rhs], req[DotOut::out], &ret);
} else if (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
out_stype == kDefaultStorage && !(param.transpose_a)) {
- NDArray ret = outputs[0];
- DotDnsCsrDnsImpl(ctx, xpu(), inputs[0].data(), inputs[1], req[0], &ret, param.transpose_b);
+ NDArray ret = outputs[DotOut::out];
+ DotDnsCsrDnsImpl(ctx,
+ xpu(),
+ inputs[DotIn::lhs].data(),
+ inputs[DotIn::rhs],
+ req[DotOut::out],
+ &ret,
+ param.transpose_b);
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
@@ -1471,47 +1537,72 @@
using namespace mshadow;
using namespace mshadow::expr;
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
- if (req[0] == kNullOp)
+ if (req[DotOut::out] == kNullOp)
return;
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
- CHECK_EQ(outputs[0].type_flag_, inputs[0].type_flag_)
+ CHECK_EQ(outputs[DotOut::out].type_flag_, inputs[DotIn::lhs].type_flag_)
<< "Binary function only support input/output with the same type";
- CHECK_EQ(outputs[0].type_flag_, inputs[1].type_flag_)
+ CHECK_EQ(outputs[DotOut::out].type_flag_, inputs[DotIn::rhs].type_flag_)
<< "Binary function only support input/output with the same type";
- CHECK(outputs[0].type_flag_ == kFloat32 || outputs[0].type_flag_ == kFloat64 ||
- (outputs[0].type_flag_ == kFloat16 && ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask))
+ CHECK(outputs[DotOut::out].type_flag_ == kFloat32 ||
+ outputs[DotOut::out].type_flag_ == kFloat64 ||
+ (outputs[DotOut::out].type_flag_ == kFloat16 &&
+ ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask))
<< "dot only supports float32/float64 for CPU, and float16/float32/float64 for GPU";
- MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
- int ndim = outputs[0].ndim();
- if (outputs[0].shape_.Size() == 0 || inputs[0].shape_.Size() == 0 ||
- inputs[1].shape_.Size() == 0) {
- if (outputs[0].shape_.Size() != 0 && req[0] != kAddTo) {
+ MSHADOW_REAL_TYPE_SWITCH(outputs[DotOut::out].type_flag_, DType, {
+ int ndim = outputs[DotOut::out].ndim();
+ if (outputs[DotOut::out].shape_.Size() == 0 || inputs[DotIn::lhs].shape_.Size() == 0 ||
+ inputs[DotIn::rhs].shape_.Size() == 0) {
+ if (outputs[DotOut::out].shape_.Size() != 0 && req[DotOut::out] != kAddTo) {
mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
- s, outputs[0].shape_.Size(), outputs[0].dptr<DType>());
+ s, outputs[DotOut::out].shape_.Size(), outputs[DotOut::out].dptr<DType>());
}
return;
}
- size_t batch_size = outputs[0].shape_.ProdShape(0, ndim - 2);
- mshadow::Tensor<xpu, 3, DType> out = outputs[0].get_with_shape<xpu, 3, DType>(
- Shape3(batch_size, outputs[0].shape_[ndim - 2], outputs[0].shape_[ndim - 1]), s);
- mshadow::Tensor<xpu, 3, DType> mlhs = inputs[0].get_with_shape<xpu, 3, DType>(
- Shape3(batch_size, inputs[0].shape_[ndim - 2], inputs[0].shape_[ndim - 1]), s);
- mshadow::Tensor<xpu, 3, DType> mrhs = inputs[1].get_with_shape<xpu, 3, DType>(
- Shape3(batch_size, inputs[1].shape_[ndim - 2], inputs[1].shape_[ndim - 1]), s);
+ size_t batch_size = outputs[DotOut::out].shape_.ProdShape(0, ndim - 2);
+ mshadow::Tensor<xpu, 3, DType> out = outputs[DotOut::out].get_with_shape<xpu, 3, DType>(
+ Shape3(batch_size,
+ outputs[DotOut::out].shape_[ndim - 2],
+ outputs[DotOut::out].shape_[ndim - 1]),
+ s);
+ mshadow::Tensor<xpu, 3, DType> mlhs = inputs[DotIn::lhs].get_with_shape<xpu, 3, DType>(
+ Shape3(
+ batch_size, inputs[DotIn::lhs].shape_[ndim - 2], inputs[DotIn::lhs].shape_[ndim - 1]),
+ s);
+ mshadow::Tensor<xpu, 3, DType> mrhs = inputs[DotIn::rhs].get_with_shape<xpu, 3, DType>(
+ Shape3(
+ batch_size, inputs[DotIn::rhs].shape_[ndim - 2], inputs[DotIn::rhs].shape_[ndim - 1]),
+ s);
mshadow::Tensor<xpu, 1, DType*> workspace =
ctx.requested[0].get_space_typed<xpu, 1, DType*>(mshadow::Shape1(3 * out.size(0)), s);
if (param.transpose_a && param.transpose_b) {
- mshadow::BatchGEMM<true, true>(
- out, mlhs, mrhs, (DType)1.0f, (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, workspace);
+ mshadow::BatchGEMM<true, true>(out,
+ mlhs,
+ mrhs,
+ (DType)1.0f,
+ (kAddTo == req[DotOut::out]) ? (DType)1.0f : (DType)0.0f,
+ workspace);
} else if (!param.transpose_a && param.transpose_b) {
- mshadow::BatchGEMM<false, true>(
- out, mlhs, mrhs, (DType)1.0f, (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, workspace);
+ mshadow::BatchGEMM<false, true>(out,
+ mlhs,
+ mrhs,
+ (DType)1.0f,
+ (kAddTo == req[DotOut::out]) ? (DType)1.0f : (DType)0.0f,
+ workspace);
} else if (param.transpose_a && !param.transpose_b) {
- mshadow::BatchGEMM<true, false>(
- out, mlhs, mrhs, (DType)1.0f, (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, workspace);
+ mshadow::BatchGEMM<true, false>(out,
+ mlhs,
+ mrhs,
+ (DType)1.0f,
+ (kAddTo == req[DotOut::out]) ? (DType)1.0f : (DType)0.0f,
+ workspace);
} else {
- mshadow::BatchGEMM<false, false>(
- out, mlhs, mrhs, (DType)1.0f, (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, workspace);
+ mshadow::BatchGEMM<false, false>(out,
+ mlhs,
+ mrhs,
+ (DType)1.0f,
+ (kAddTo == req[DotOut::out]) ? (DType)1.0f : (DType)0.0f,
+ workspace);
}
});
}
@@ -1565,4 +1656,16 @@
} // namespace op
} // namespace mxnet
+namespace std {
+template <>
+struct hash<mxnet::op::DotParam> {
+  size_t operator()(const mxnet::op::DotParam& val) const {
+ size_t ret = 0;
+ ret = dmlc::HashCombine(ret, val.transpose_a);
+ ret = dmlc::HashCombine(ret, val.transpose_b);
+ ret = dmlc::HashCombine(ret, val.forward_stype);
+ return ret;
+ }
+};
+} // namespace std
#endif // MXNET_OPERATOR_TENSOR_DOT_INL_H_
diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc
index defed11..0d62cac 100644
--- a/src/operator/tensor/dot.cc
+++ b/src/operator/tensor/dot.cc
@@ -24,9 +24,8 @@
#include "./dot-inl.h"
#if MXNET_USE_ONEDNN == 1
-#include "./../nn/dnnl/dnnl_base-inl.h"
-#include "./../nn/dnnl/dnnl_ops-inl.h"
-#include "./../nn/dnnl/dnnl_batch_dot-inl.h"
+#include "operator/nn/dnnl/dnnl_batch_dot-inl.h"
+#include "operator/nn/dnnl/dnnl_dot-inl.h"
#endif // MXNET_USE_ONEDNN
namespace mxnet {
@@ -97,6 +96,9 @@
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<FCompute>("FCompute<cpu>", DotForward_<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", DotForwardEx<cpu>)
+#if MXNET_USE_ONEDNN == 1
+ .set_attr<bool>("TIsMKLDNN", true)
+#endif
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_dot"})
.add_argument("lhs", "NDArray-or-Symbol", "The first input")
.add_argument("rhs", "NDArray-or-Symbol", "The second input")
@@ -117,12 +119,26 @@
.add_arguments(DotParam::__FIELDS__());
#if MXNET_USE_ONEDNN == 1
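+// Runs the oneDNN dot kernel when the shapes and types are supported, otherwise falls back to
+// DotForward_<cpu>.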
+void DotForwardExDNNL(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ if (SupportDNNLDot(inputs, outputs[DotOut::out])) {
+ DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+ DNNLRun(DNNLDotForward<false>, attrs, ctx, inputs, req, outputs);
+ DNNL_OPCHECK_RUN(DotForward_<cpu>, attrs, ctx, inputs, req, outputs);
+ } else {
+ FallBackCompute(DotForward_<cpu>, attrs, ctx, inputs, req, outputs);
+ }
+}
+
static void BatchDotComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
- if (SupportDNNLBatchDot(inputs, outputs[0])) {
+ if (SupportDNNLBatchDot(inputs, outputs[DotOut::out])) {
DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
DNNLRun(DNNLBatchDotForward<false>, attrs, ctx, inputs, req, outputs);
DNNL_OPCHECK_RUN(BatchDotForward_<cpu>, attrs, ctx, inputs, req, outputs);