/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file dnnl_dot.cc
 */
#if MXNET_USE_ONEDNN == 1
#include <memory>
#include <unordered_map>
#include "dnnl_dot-inl.h"
namespace mxnet {
namespace op {
bool SupportDNNLDot(const std::vector<NDArray>& inputs, const NDArray& output) {
#if MXNET_USE_BLAS_MKL == 1
  return false;
#endif
  return inputs[DotIn::lhs].shape().Size() > 1 && inputs[DotIn::rhs].shape().Size() > 1 &&
         inputs[DotIn::lhs].shape().ndim() > 0 && inputs[DotIn::rhs].shape().ndim() > 0 &&
         output.shape().Size() != 0 && output.shape().ndim() > 0 && output.shape().ndim() <= 12 &&
         (inputs[DotIn::lhs].dtype() == mshadow::kFloat32 ||
          inputs[DotIn::lhs].dtype() == mshadow::kBfloat16);
}
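
// GetCached returns a forward primitive from a per-thread cache. The cache key combines the
// dot parameters, the shapes and dtypes of both inputs and the output, and the isNumpy flag,
// so a matmul primitive is created once per unique configuration and reused on later calls.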
DNNLDotFwd& DNNLDotFwd::GetCached(const DotParam& param,
                                  const std::vector<NDArray>& inputs,
                                  const std::vector<NDArray>& outputs,
                                  const bool isNumpy) {
  using dot_fwd_map = std::unordered_map<DotSignature, DNNLDotFwd, OpHash>;
#if DMLC_CXX11_THREAD_LOCAL
  static thread_local dot_fwd_map fwds;
#else
  static MX_THREAD_LOCAL dot_fwd_map fwds;
#endif

  DotSignature key(param);
  key.AddSign(inputs[DotIn::lhs]);
  key.AddSign(inputs[DotIn::rhs]);
  key.AddSign(outputs[DotOut::out]);
  key.AddSign(static_cast<int>(isNumpy));

  auto it = fwds.find(key);
  if (it == fwds.end()) {
    const DNNLDotFwd fwd(param, inputs, outputs, isNumpy);
    it = AddToCache(&fwds, key, fwd);
  }
  return it->second;
}
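
// GetMemoryDesc describes a tensor as a flattened 2D matrix of firstDim x secondDim elements.
// With transpose == false the buffer is read row-major (format_tag::ab); with transpose == true
// the same buffer is interpreted as its transpose (format_tag::ba), e.g. a row-major KxM buffer
// viewed as an MxK matrix, so no data copy is needed for transposed inputs.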
auto GetMemoryDesc(const NDArray& tensor, int firstDim, int secondDim, const bool transpose) {
  return dnnl::memory::desc(
      dnnl::memory::dims{firstDim, secondDim},
      get_dnnl_type(tensor.dtype()),
      transpose ? dnnl::memory::format_tag::ba : dnnl::memory::format_tag::ab);
}
DNNLDotFwd::DNNLDotFwd(const DotParam& param,
                       const std::vector<NDArray>& inputs,
                       const std::vector<NDArray>& outputs,
                       const bool isNumpy) {
  auto shapeLhs = inputs[DotIn::lhs].shape(), shapeRhs = inputs[DotIn::rhs].shape();
  auto ndimLhs = shapeLhs.ndim(), ndimRhs = shapeRhs.ndim();
  dnnl::memory::desc lhs_md, rhs_md, out_md;
  // NumPy lays out an rhs tensor with more than 2 dimensions as Ax...xKxN, which differs from
  // NDArray's KxAx...xN layout. For NumPy the shape in the rhs memory descriptor is still
  // Kx(A*...*N), as for NDArray, but for it to match the actual data an additional reorder is
  // needed, permuting the last two axes: Ax...xKxN -> Ax...xNxK. For that data to match the
  // Kx(A*...*N) shape, format_tag needs to be set to ba. The reorder described above is
  // implemented in the Execute function.
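  // For example, with isNumpy == true and an rhs of shape (A, B, K, N): smallDimRhs == K,
  // bigDimRhs == A*B*N, and rhs_md describes a Kx(A*B*N) matrix with format_tag::ba, which
  // matches the data only after Execute has reordered rhs to shape (A, B, N, K).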
  const bool differentNumpy = isNumpy && ndimRhs > 2;
  const int smallDimLhs = param.transpose_a ? shapeLhs[0] : shapeLhs[ndimLhs - 1];
  const int bigDimLhs   = shapeLhs.Size() / smallDimLhs;
  const int smallDimRhs = param.transpose_b ?
                              shapeRhs[ndimRhs - 1] :
                              (differentNumpy ? shapeRhs[ndimRhs - 2] : shapeRhs[0]);
  const int bigDimRhs   = shapeRhs.Size() / smallDimRhs;
  lhs_md = GetMemoryDesc(inputs[DotIn::lhs], bigDimLhs, smallDimLhs, param.transpose_a);
  rhs_md = GetMemoryDesc(
      inputs[DotIn::rhs], smallDimRhs, bigDimRhs, param.transpose_b || differentNumpy);
  out_md = dnnl::memory::desc({bigDimLhs, bigDimRhs},
                              get_dnnl_type(outputs[DotOut::out].dtype()),
                              dnnl::memory::format_tag::any);
  dnnl::matmul::desc fwd_desc(lhs_md, rhs_md, out_md);
  fwd_pd = std::make_shared<dot_fwd_pd_t>(fwd_desc, mxnet::CpuEngine::Get()->get_engine());
  fwd    = std::make_shared<dot_fwd_t>(*fwd_pd);
}
void DNNLDotFwd::Execute(const std::vector<NDArray>& inputs,
                         const std::vector<OpReqType>& req,
                         const std::vector<NDArray>& outputs,
                         const bool isNumpy) {
  auto engine = mxnet::CpuEngine::Get()->get_engine();
  auto lhs    = dnnl::memory(
      fwd_pd->src_desc(), engine, reinterpret_cast<void*>(inputs[DotIn::lhs].data().dptr_));
  auto rhs     = dnnl::memory(fwd_pd->weights_desc(), engine);
  auto ndimRhs = inputs[DotIn::rhs].shape().ndim();
  if (isNumpy && ndimRhs > 2) {
    // The reason this reorder is needed is described in the DNNLDotFwd constructor.
    auto tmp_rhs = inputs[DotIn::rhs].GetDNNLData();
    dnnl::memory::desc rhs_md(
        dnnl::memory::dims(inputs[DotIn::rhs].shape().begin(), inputs[DotIn::rhs].shape().end()),
        get_dnnl_type(inputs[DotIn::rhs].dtype()),
        static_cast<dnnl::memory::format_tag>(GetPermutedFormat(ndimRhs)));
    dnnl::memory tmp_rhs_dst(rhs_md, engine, rhs.get_data_handle());
    const auto rhs_reorder_pd = dnnl::reorder::primitive_desc(*tmp_rhs, tmp_rhs_dst);
    DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(rhs_reorder_pd),
                                        {{DNNL_ARG_FROM, *tmp_rhs}, {DNNL_ARG_TO, tmp_rhs_dst}});
  } else {
    rhs.set_data_handle(reinterpret_cast<void*>(inputs[DotIn::rhs].data().dptr_));
  }
  dnnl_output_t out_mem = CreateDNNLMem(
      outputs[DotOut::out], fwd_pd->dst_desc(), req[DotOut::out], &inputs[DotIn::lhs]);
  dnnl_args_map_t args = {
      {DNNL_ARG_SRC, lhs},
      {DNNL_ARG_WEIGHTS, rhs},
      {DNNL_ARG_DST, *out_mem.second},
  };
  DNNLStream::Get()->RegisterPrimArgs(*fwd, args);
  CommitOutput(outputs[DotOut::out], out_mem);
  DNNLStream::Get()->Submit();
}
} // namespace op
} // namespace mxnet
#endif // MXNET_USE_ONEDNN == 1