/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * \file dnnl_quantized_conv.cc
 * \brief Forward pass of quantized (int8) convolution implemented with oneDNN.
 * \author Wenting Jiang, Xinyu Chen
 */
#if MXNET_USE_ONEDNN == 1

#include "operator/elemwise_op_common.h"
#include "operator/nn/convolution-inl.h"
#include "operator/nn/dnnl/dnnl_base-inl.h"
#include "operator/nn/dnnl/dnnl_convolution-inl.h"
#include "operator/tensor/matrix_op-inl.h"
#include "operator/quantization/quantization_utils.h"

namespace mxnet {
namespace op {

static void DNNLQuantizedConvForward(const nnvm::NodeAttrs& attrs,
                                     const OpContext& ctx,
                                     const std::vector<NDArray>& in_data,
                                     const std::vector<OpReqType>& req,
                                     const std::vector<NDArray>& out_data) {
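  // Serve temporary oneDNN buffers (e.g. reorder scratch) from the operator's
  // requested temp workspace.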
  TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]);
  NDArray weight         = in_data[conv::kWeight];
  ConvolutionParam param = nnvm::get<ConvolutionParam>(attrs.parsed);
  DNNLConvFullParam full_param;
  full_param.conv_param = param;
  full_param.dnnl_param.Init(std::unordered_map<std::string, std::string>());
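  // With the default (unfused) dnnl_param settings, fetch, or create and
  // cache, the oneDNN convolution forward primitive matching these parameters
  // and the shapes/dtypes of the input and output arrays.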
  auto& fwd = GetConvFwd(full_param,
                         ctx.is_train,
                         in_data[conv::kData],
                         in_data[conv::kWeight],
                         param.no_bias ? nullptr : &in_data[conv::kBias],
                         out_data[conv::kOut]);
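  // Reorder the input data into the source layout the primitive expects,
  // if it is not already in that format.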
  auto fwd_src_desc = fwd.GetPd().src_desc();
  auto data_mem     = in_data[conv::kData].GetDNNLDataReorder(&fwd_src_desc);
  const dnnl::memory* weight_mem;
  // For inference, reorder the weight array once so that it does not have to
  // be reordered on every forward call.
  if (weight.IsDefaultData()) {
    // We also need to modify the layout on the original weight array.
    // Do not swap the two statements below, because the naive engine executes
    // pushAsync synchronously.
    auto fwd_weight_desc = fwd.GetPd().weights_desc();
    weight.DNNLDataReorderAsync(&fwd_weight_desc);
    weight_mem = GetWeights(weight, fwd_weight_desc, param.num_group);
  } else {
    weight_mem = weight.GetDNNLData();
  }
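  // out_mem is a pair of (output action, destination memory) created according
  // to req[conv::kOut]; the primitive writes into out_mem.second and
  // CommitOutput applies the result back to out_data (e.g. accumulating for
  // kAddTo).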
  auto out_mem = CreateDNNLMem(out_data[conv::kOut], fwd.GetPd().dst_desc(), req[conv::kOut]);
  dnnl_args_map_t net_args;
  if (!param.no_bias) {
    auto fwd_bias_desc           = fwd.GetPd().bias_desc();
    const dnnl::memory* bias_mem = in_data[conv::kBias].GetDNNLDataReorder(&fwd_bias_desc);
    net_args.insert({DNNL_ARG_BIAS, *bias_mem});
  }
  net_args.insert({DNNL_ARG_SRC, *data_mem});
  net_args.insert({DNNL_ARG_WEIGHTS, *weight_mem});
  net_args.insert({DNNL_ARG_DST, *out_mem.second});
  DNNLStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
  CommitOutput(out_data[conv::kOut], out_mem);
  DNNLStream::Get()->Submit();
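
  // Compute the output quantization range: in_data[num_inputs] through
  // in_data[num_inputs + 3] hold min_data, max_data, min_weight and max_weight,
  // and the kernel writes the min/max range representable by the int8 * int8
  // convolution output into out_data[1] and out_data[2].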
  Stream<cpu>* s          = ctx.get_stream<cpu>();
  const size_t num_inputs = param.no_bias ? 2 : 3;
  mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, cpu>::Launch(
      s,
      1,
      out_data[1].data().dptr<float>(),
      out_data[2].data().dptr<float>(),
      in_data[num_inputs].data().dptr<float>(),
      in_data[num_inputs + 1].data().dptr<float>(),
      in_data[num_inputs + 2].data().dptr<float>(),
      in_data[num_inputs + 3].data().dptr<float>());
}
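
// Register this routine as the CPU FComputeEx implementation of quantized
// convolution when MXNet is built with oneDNN support.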
NNVM_REGISTER_OP(_contrib_quantized_conv)
    .set_attr<FComputeEx>("FComputeEx<cpu>", DNNLQuantizedConvForward);

} // namespace op
} // namespace mxnet
#endif // MXNET_USE_ONEDNN == 1