blob: 3f2a5e0be5cc1fd2c6aca03b1350eadfc618f43a [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file dnnl_softmax.cc
* \brief oneDNN (DNNL) implementation of the softmax operator: support check, forward and backward passes.
* \author Da Zheng
*/
#if MXNET_USE_ONEDNN == 1
#include "dnnl_softmax-inl.h"
namespace mxnet {
namespace op {
// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_softmax.html
/*!
 * \brief Decide whether this softmax call can be dispatched to the oneDNN implementation.
 * \param param softmax operator parameters
 * \param data  input array
 * \return true only if all of the following hold:
 *         - temperature is not explicitly 0 (the temperature path scales by 1/temperature);
 *         - the normalized softmax axis is the last dimension;
 *         - the input passes SupportDNNL<DNNLTypeMode::NoInt32>;
 *         - the requested output dtype (input dtype when unset) equals the input dtype.
 */
bool SupportDNNLSoftmax(const SoftmaxParam& param, const NDArray& data) {
  const int ndim = data.shape().ndim();
  const bool zero_temperature =
      param.temperature.has_value() && param.temperature.value() == 0.0;
  if (zero_temperature) {
    return false;
  }
  if (CheckAxis(param.axis, ndim) != ndim - 1) {
    return false;
  }
  if (!SupportDNNL<DNNLTypeMode::NoInt32>(data)) {
    return false;
  }
  const int requested_dtype = param.dtype.has_value() ? param.dtype.value() : data.dtype();
  return requested_dtype == data.dtype();
}
/*!
 * \brief oneDNN softmax forward entry point.
 *
 * Returns immediately for kNullOp and rejects kAddTo. Otherwise it fetches (or builds) the
 * cached forward primitive for this configuration and executes it.
 */
void DNNLSoftmaxForward(const nnvm::NodeAttrs& attrs,
                        const OpContext& ctx,
                        const NDArray& in_data,
                        const OpReqType& req,
                        const NDArray& out_data) {
  if (req == kNullOp) {
    return;  // nothing to write
  }
  // Same as the FCompute path, softmax only supports kWriteTo and kWriteInplace for now
  CHECK_NE(req, kAddTo);
  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
  // Temperature scaling may need a scratch buffer (DNNLSoftmaxFwd::Execute allocates it via
  // TmpMemMgr), so the temporary memory manager is primed with the requested resource first.
  if (param.temperature.has_value()) {
    TmpMemMgr::Get()->Init(ctx.requested[0]);
  }
  const DNNLSoftmaxFwd::Tensors tensors(in_data, out_data);
  DNNLSoftmaxFwd::GetCached(param, tensors, ctx.is_train).Execute(tensors);
}
DNNLSoftmaxFwd::Tensors::Tensors(const NDArray& data, const NDArray& output)
: data(data), out(output) {}
using DNNLSoftmaxSignature = ParamOpSign<SoftmaxParam>;
/*!
 * \brief Return the forward primitive wrapper for this configuration, building it on first use.
 *
 * The cache is a function-local static declared thread_local (or MX_THREAD_LOCAL), so each
 * thread keeps its own entries.
 *
 * The key combines:
 *  - the operator parameters;
 *  - the normalized axis;
 *  - the training flag;
 *  - the effective temperature (1.0 when unset);
 *  - the input and output array signatures.
 */
DNNLSoftmaxFwd& DNNLSoftmaxFwd::GetCached(const SoftmaxParam& param,
                                          const Tensors& tensors,
                                          const bool is_train) {
  using FwdCache = std::unordered_map<DNNLSoftmaxSignature, DNNLSoftmaxFwd, OpHash>;
#if DMLC_CXX11_THREAD_LOCAL
  static thread_local FwdCache fwds;
#else
  static MX_THREAD_LOCAL FwdCache fwds;
#endif
  DNNLSoftmaxSignature signature(param);
  const float effective_temperature =
      param.temperature.has_value() ? param.temperature.value() : 1.0f;
  const int softmax_axis = CheckAxis(param.axis, tensors.data.shape().ndim());
  signature.AddSign(softmax_axis);
  signature.AddSign(is_train);
  signature.AddSign(effective_temperature);
  signature.AddSign(tensors.data);
  signature.AddSign(tensors.out);

  auto hit = fwds.find(signature);
  if (hit != fwds.end()) {
    return hit->second;
  }
  DNNLSoftmaxFwd created(param, tensors, is_train);
  return AddToCache(&fwds, signature, created)->second;
}
/*!
 * \brief Build the forward primitives for one cached configuration.
 *
 * The softmax primitive is always created from the input's memory descriptor. An extra
 * eltwise_linear primitive, which scales by 1/temperature, is created only when the effective
 * temperature differs from 1.
 */
DNNLSoftmaxFwd::DNNLSoftmaxFwd(const SoftmaxParam& param,
                               const Tensors& tensors,
                               const bool is_train) {
  const int softmax_axis = CheckAxis(param.axis, tensors.data.shape().ndim());
  const auto src_mem = tensors.data.GetDNNLData();
  softmax_pd =
      std::make_shared<softmax_fwd_pd_t>(GetSoftmaxFwdPd(*src_mem, softmax_axis, is_train));
  softmax_fwd = std::make_shared<softmax_fwd_t>(*softmax_pd);

  const float effective_temperature =
      param.temperature.has_value() ? param.temperature.value() : 1.0f;
  if (effective_temperature == 1.0f) {
    return;  // scaling by 1 is the identity; no extra primitive needed
  }
  temperature_pd =
      std::make_shared<linear_pd_t>(GetTemperaturePd(*src_mem, effective_temperature));
  temperature_fwd = std::make_shared<linear_t>(*temperature_pd);
}
/*!
 * \brief Create a oneDNN softmax forward primitive descriptor on the CPU engine.
 * \param input_mem memory whose descriptor defines the softmax source layout
 * \param axis      normalized softmax axis
 * \param is_train  selects forward_training (true) or forward_scoring (false)
 */
softmax_fwd_pd_t DNNLSoftmaxFwd::GetSoftmaxFwdPd(const dnnl::memory& input_mem,
                                                 const int axis,
                                                 const bool is_train) {
  dnnl::prop_kind propagation = dnnl::prop_kind::forward_scoring;
  if (is_train) {
    propagation = dnnl::prop_kind::forward_training;
  }
  const dnnl::softmax_forward::desc op_desc(propagation, input_mem.get_desc(), axis);
  return softmax_fwd_pd_t(op_desc, CpuEngine::Get()->get_engine());
}
/*!
 * \brief Create an eltwise_linear primitive descriptor that divides its input by temperature.
 *
 * eltwise_linear computes alpha * x + beta. With alpha = 1/temperature and beta = 0 this is
 * the temperature scaling. The caller must ensure temperature != 0.
 * \param input_mem   memory whose descriptor defines the primitive's layout
 * \param temperature softmax temperature
 */
linear_pd_t DNNLSoftmaxFwd::GetTemperaturePd(const dnnl::memory& input_mem,
                                             const float temperature) {
  const float alpha = 1.0f / temperature;
  const float beta = 0.0f;
  const dnnl::eltwise_forward::desc op_desc(dnnl::prop_kind::forward_scoring,
                                            dnnl::algorithm::eltwise_linear,
                                            input_mem.get_desc(),
                                            alpha,
                                            beta);
  return linear_pd_t(op_desc, CpuEngine::Get()->get_engine());
}
void DNNLSoftmaxFwd::Execute(const Tensors& tensors) const {
DNNLStream* stream = DNNLStream::Get();
auto original_input_mem = tensors.data.GetDNNLData();
auto softmax_pd_dst_desc = softmax_pd->dst_desc();
const auto out_mem = tensors.out.GetDNNLData(&softmax_pd_dst_desc);
dnnl::memory* softmax_input_mem;
if (temperature_pd) {
// check whether additional buffer is needed, when temperature parameter is being used
if (original_input_mem->get_desc() != out_mem->get_desc()) {
softmax_input_mem = TmpMemMgr::Get()->Alloc(original_input_mem->get_desc());
} else {
softmax_input_mem = const_cast<dnnl::memory*>(out_mem);
}
stream->RegisterPrimArgs(
*temperature_fwd,
{{DNNL_ARG_SRC, *original_input_mem}, {DNNL_ARG_DST, *softmax_input_mem}});
} else {
softmax_input_mem = const_cast<dnnl::memory*>(original_input_mem);
}
stream->RegisterPrimArgs(*softmax_fwd,
{{DNNL_ARG_SRC, *softmax_input_mem}, {DNNL_ARG_DST, *out_mem}});
stream->Submit();
}
/*!
 * \brief Create a oneDNN softmax backward primitive descriptor on the CPU engine.
 * \param out_grad_mem memory of the incoming gradient (diff_dst); its descriptor is the diff layout
 * \param out_mem      memory of the forward softmax output (dst)
 * \param axis         normalized softmax axis
 * \param hint_fwd_pd  forward primitive descriptor used as the oneDNN hint
 */
softmax_bwd_pd_t DNNLSoftmaxBwd::GetSoftmaxBwdPd(const dnnl::memory& out_grad_mem,
                                                 const dnnl::memory& out_mem,
                                                 const int axis,
                                                 const softmax_fwd_pd_t& hint_fwd_pd) {
  const dnnl::softmax_backward::desc op_desc(out_grad_mem.get_desc(), out_mem.get_desc(), axis);
  return softmax_bwd_pd_t(op_desc, CpuEngine::Get()->get_engine(), hint_fwd_pd);
}
/*!
 * \brief oneDNN softmax backward entry point.
 *
 * Returns immediately when req[0] is kNullOp. Otherwise it fetches (or builds) the cached
 * backward primitive and executes it with the given write requests.
 */
void DNNLSoftmaxBackward(const nnvm::NodeAttrs& attrs,
                         const OpContext& ctx,
                         const std::vector<NDArray>& inputs,
                         const std::vector<OpReqType>& req,
                         const std::vector<NDArray>& outputs) {
  if (req[0] == kNullOp) {
    return;  // gradient not requested
  }
  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
  // Temperature scaling may allocate a scratch buffer through TmpMemMgr in Execute.
  if (param.temperature.has_value()) {
    TmpMemMgr::Get()->Init(ctx.requested[0]);
  }
  const DNNLSoftmaxBwd::Tensors tensors(inputs, outputs);
  DNNLSoftmaxBwd::GetCached(param, tensors).Execute(tensors, req);
}
DNNLSoftmaxBwd::Tensors::Tensors(const std::vector<NDArray>& inputs,
const std::vector<NDArray>& outputs)
: out_grad(inputs[0]), out(inputs[1]), data_grad(outputs[0]) {}
/*!
 * \brief Return the backward primitive wrapper for this configuration, building it on first use.
 *
 * The cache is a per-thread static. The key combines:
 *  - the operator parameters;
 *  - the normalized axis;
 *  - the forward-output and output-gradient signatures;
 *  - the effective temperature (1.0 when unset).
 */
DNNLSoftmaxBwd& DNNLSoftmaxBwd::GetCached(const SoftmaxParam& param, const Tensors& tensors) {
  using BwdCache = std::unordered_map<DNNLSoftmaxSignature, DNNLSoftmaxBwd, OpHash>;
#if DMLC_CXX11_THREAD_LOCAL
  static thread_local BwdCache bwds;
#else
  static MX_THREAD_LOCAL BwdCache bwds;
#endif
  const float effective_temperature =
      param.temperature.has_value() ? param.temperature.value() : 1.0f;
  const int softmax_axis = CheckAxis(param.axis, tensors.out.shape().ndim());
  DNNLSoftmaxSignature signature(param);
  signature.AddSign(softmax_axis);
  signature.AddSign(tensors.out);
  signature.AddSign(tensors.out_grad);
  signature.AddSign(effective_temperature);

  auto hit = bwds.find(signature);
  if (hit != bwds.end()) {
    return hit->second;
  }
  DNNLSoftmaxBwd created(param, tensors);
  return AddToCache(&bwds, signature, created)->second;
}
/*!
 * \brief Build the backward primitives for one cached configuration.
 *
 * The softmax backward primitive descriptor is created from two memory descriptors: the output
 * gradient (diff_dst) and the forward output (dst). It is hinted by a forward-training softmax
 * descriptor built on the forward output.
 *
 * When the effective temperature differs from 1, an eltwise_linear primitive (scale by
 * 1/temperature) is also built. Execute applies it to the output gradient before softmax
 * backward.
 */
DNNLSoftmaxBwd::DNNLSoftmaxBwd(const SoftmaxParam& param, const Tensors& tensors) {
  const float temperature = param.temperature.has_value() ? param.temperature.value() : 1.0f;
  const int axis = CheckAxis(param.axis, tensors.out.shape().ndim());
  const auto out_grad_mem = tensors.out_grad.GetDNNLData();
  const auto out_mem = tensors.out.GetDNNLData();
  const auto softmax_fwd_pd = DNNLSoftmaxFwd::GetSoftmaxFwdPd(*out_mem, axis, /*is_train*/ true);
  softmax_bwd_pd = std::make_shared<softmax_bwd_pd_t>(
      GetSoftmaxBwdPd(*out_grad_mem, *out_mem, axis, softmax_fwd_pd));
  softmax_bwd = std::make_shared<softmax_bwd_t>(*softmax_bwd_pd);
  if (temperature != 1.0f) {  // avoid dividing by 1
    // Execute feeds the output gradient into this primitive, and its destination buffer also
    // carries the output gradient's descriptor. The primitive must therefore be described by the
    // output gradient's memory descriptor, not the forward output's. The two are obtained
    // independently and are not guaranteed to share a layout.
    temperature_pd =
        std::make_shared<linear_pd_t>(DNNLSoftmaxFwd::GetTemperaturePd(*out_grad_mem, temperature));
    temperature_fwd = std::make_shared<linear_t>(*temperature_pd);
  }
}
void DNNLSoftmaxBwd::Execute(const Tensors& tensors, const std::vector<OpReqType>& req) const {
DNNLStream* stream = DNNLStream::Get();
const auto original_out_grad_mem = tensors.out_grad.GetDNNLData();
const auto out_mem = tensors.out.GetDNNLData();
const auto data_grad_mem =
CreateDNNLMem(tensors.data_grad, softmax_bwd_pd->diff_src_desc(), req[0]);
dnnl::memory* out_grad_mem;
if (temperature_fwd) {
// check whether additional buffer is needed, when temperature parameter is being used
if (original_out_grad_mem->get_desc() != softmax_bwd_pd->diff_src_desc()) {
out_grad_mem = TmpMemMgr::Get()->Alloc(original_out_grad_mem->get_desc());
} else {
out_grad_mem = const_cast<dnnl::memory*>(data_grad_mem.second);
}
stream->RegisterPrimArgs(
*temperature_fwd, {{DNNL_ARG_SRC, *original_out_grad_mem}, {DNNL_ARG_DST, *out_grad_mem}});
} else {
out_grad_mem = const_cast<dnnl::memory*>(original_out_grad_mem);
}
dnnl_args_map_t args = {{DNNL_ARG_DST, *out_mem},
{DNNL_ARG_DIFF_DST, *out_grad_mem},
{DNNL_ARG_DIFF_SRC, *data_grad_mem.second}};
stream->RegisterPrimArgs(*softmax_bwd, args);
CommitOutput(tensors.data_grad, data_grad_mem);
stream->Submit();
}
} // namespace op
} // namespace mxnet
#endif