blob: 54da2737126c0080abe59dd43d34fd73ad17cd70 [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file dnnl_fc.cc
* \brief DNNL (Quantized) FullyConnected operator based on subgraph
* \author Ciyong Chen
*/
#if MXNET_USE_ONEDNN == 1
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "operator/nn/dnnl/dnnl_act-inl.h"
#include "operator/nn/dnnl/dnnl_base-inl.h"
#include "operator/nn/dnnl/dnnl_fully_connected-inl.h"
#include "operator/quantization/quantization_utils.h"
#include "operator/tensor/matrix_op-inl.h"
#include "operator/subgraph/common.h"
#include "dnnl_common.h"
#include "dnnl_fc-inl.h"
namespace mxnet {
namespace op {
/*!
 * \brief Stateful operator that runs a fused (optionally quantized) FullyConnected subgraph.
 *
 * The object keeps the parsed parameters, the forward primitive and the memories/scales built
 * during the first quantized Forward call so that later calls can reuse them.
 */
class SgDNNLFCOp {
 public:
  /*!
   * \brief Copies the first subgraph symbol, the node attributes and the parsed parameters.
   * \param attrs node attributes; `attrs.parsed` must hold a DNNLFCFullParam
   */
  explicit SgDNNLFCOp(const nnvm::NodeAttrs& attrs)
      : subgraph_sym_(*attrs.subgraphs[0]),
        attrs(attrs),
        full_param_(nnvm::get<DNNLFCFullParam>(attrs.parsed)) {}

  /*! \brief Executes the fused operator; see the out-of-class definition. */
  void Forward(const OpContext& ctx,
               const std::vector<NDArray>& inputs,
               const std::vector<OpReqType>& req,
               const std::vector<NDArray>& outputs);

  /*! \brief Not supported: always raises a fatal error. */
  void Backward(const OpContext& ctx,
                const std::vector<NDArray>& inputs,
                const std::vector<OpReqType>& req,
                const std::vector<NDArray>& outputs) {
    LOG(FATAL) << "Not implemented: subgraph oneDNN fully connected only supports "
                  "inference computation.";
  }

 private:
  // Positions of the individual ranges inside the min/max vector built by Forward.
  enum { kDataMin = 0, kDataMax, kWeightMin, kWeightMax, kBiasMin, kBiasMax, kSumMin, kSumMax };
  // Number of entries of the enum above.
  static constexpr size_t MIN_MAX_COUNT = 8;
  // Scale applied when a uint8 sum input is reordered into an int8 output buffer.
  static constexpr float u8_to_s8_scale = 0.5f;

  NDArray PrepareOutputWithSum(const NDArray& sum_input, const NDArray& output);
  bool CheckInitializationConditions(const std::vector<NDArray>& inputs,
                                     const std::vector<float>& min_max_vec,
                                     bool is_channel_wise);
  bool PrepareQuantization(const OpContext& ctx,
                           const std::vector<NDArray>& in_data,
                           const NDArray& output,
                           const std::vector<float>& min_max_vec);
  dnnl::memory::desc CreateOutputMemoryDesc(const mxnet::TShape& oshape, int out_dtype);
  void GetCachedWeightsAndBias(const NDArray& weight,
                               bool support_channelwise_scale,
                               bool has_bias);

  nnvm::Symbol subgraph_sym_;
  nnvm::NodeAttrs attrs;
  DNNLFCFullParam full_param_;

  // State flags.
  bool initialized_{false};
  bool reorder_data_{false};
  bool inplace_{false};

  // Cached primitive, its argument map and the memories whose handles are swapped per call.
  dnnl_args_map_t args_;
  std::shared_ptr<DNNLFullyConnectedForward> fwd_;
  std::shared_ptr<dnnl::memory> cached_data_mem_;
  std::shared_ptr<dnnl::memory> cached_out_mem_;
  NDArray cached_weight_;
  NDArray cached_bias_;

  // Ranges and versions seen at initialization time. They are value-initialized so that the
  // object never carries indeterminate values, even before the first Forward call.
  float cached_data_min_{0.0f};
  float cached_data_max_{0.0f};
  float cached_weight_min_{0.0f};
  float cached_weight_max_{0.0f};
  float cached_sum_min_{0.0f};
  float cached_sum_max_{0.0f};
  float cached_bias_min_{0.0f};
  float cached_bias_max_{0.0f};
  size_t weight_ver_{0};
  size_t bias_ver_{0};
  float cached_output_min_{0.0f};
  float cached_output_max_{0.0f};

  // Quantization scales computed by PrepareQuantization.
  float data_scale_{0.0f};
  std::vector<float> weight_scales_;
};
/*!
 * \brief Runs the fused FullyConnected subgraph.
 *
 * Non-quantized parameters are dispatched to DNNLFCForwardImpl. On the quantized path the
 * min/max inputs are collected, the cached primitive, memories and scales are (re)built when
 * CheckInitializationConditions() returns false, and the cached primitive is submitted with
 * the data/output handles of this call.
 *
 * \param ctx operator context, forwarded to the helpers and to the primitive constructor
 * \param in_data operator inputs, addressed through FCInputIndex
 * \param req write requests; only forwarded to the float implementation
 * \param out_data outputs: [0] is the result; [1] and [2] receive the output min/max unless
 *        enabled_float_output is set
 */
void SgDNNLFCOp::Forward(const OpContext& ctx,
const std::vector<NDArray>& in_data,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& out_data) {
const FCInputIndex idx(full_param_);
CHECK_EQ(in_data.size(), idx.GetTotal());
const int out_index = 0;
const int out_min_index = 1;
const int out_max_index = 2;
const auto& default_param = full_param_.default_param;
const auto& dnnl_param = full_param_.dnnl_param;
// With a fused sum the primitive writes into an array prepared from the sum input.
NDArray output;
if (dnnl_param.with_sum) {
output = PrepareOutputWithSum(in_data[idx.sum], out_data[out_index]);
} else {
output = out_data[out_index];
}
if (!dnnl_param.quantized) {
// dispatch to float version
DNNLFCForwardImpl(full_param_, ctx, in_data, req, {output});
return;
}
const bool has_bias = !default_param.no_bias;
const bool channel_wise =
dnnl_param.channel_wise_quantize.has_value() && dnnl_param.channel_wise_quantize.value();
// Ranges of this call, indexed by the kDataMin..kSumMax enum; everything defaults to 0.
std::vector<float> min_max_vec(MIN_MAX_COUNT);
min_max_vec[kDataMin] = 0.0f;
min_max_vec[kDataMax] = 0.0f;
min_max_vec[kWeightMin] = 0.0f;
min_max_vec[kWeightMax] = 0.0f;
min_max_vec[kBiasMin] = 0.0f;
min_max_vec[kBiasMax] = 0.0f;
// A zero sum_min/sum_max index is treated as "no such input" and leaves the range at 0.
min_max_vec[kSumMin] = idx.sum_min ? in_data[idx.sum_min].data().dptr<float>()[0] : 0.0f;
min_max_vec[kSumMax] = idx.sum_max ? in_data[idx.sum_max].data().dptr<float>()[0] : 0.0f;
NDArray data = in_data[idx.data];
const NDArray& weight = in_data[idx.weight];
// Per-tensor weight/bias ranges are only read when channel-wise quantization is off.
if (!channel_wise) {
min_max_vec[kWeightMin] = in_data[idx.weight_min].data().dptr<float>()[0];
min_max_vec[kWeightMax] = in_data[idx.weight_max].data().dptr<float>()[0];
if (has_bias) {
min_max_vec[kBiasMin] = in_data[idx.bias_min].data().dptr<float>()[0];
min_max_vec[kBiasMax] = in_data[idx.bias_max].data().dptr<float>()[0];
}
}
min_max_vec[kDataMin] = in_data[idx.data_min].data().dptr<float>()[0];
min_max_vec[kDataMax] = in_data[idx.data_max].data().dptr<float>()[0];
// initialized_ is overwritten with the result of the check on every call.
initialized_ = CheckInitializationConditions(in_data, min_max_vec, channel_wise);
if (!initialized_) {
// (Re)build the cached state: remember ranges/versions, then create primitive and memories.
const auto engine = CpuEngine::Get()->get_engine();
cached_data_min_ = min_max_vec[kDataMin];
cached_data_max_ = min_max_vec[kDataMax];
cached_weight_min_ = min_max_vec[kWeightMin];
cached_weight_max_ = min_max_vec[kWeightMax];
weight_ver_ = weight.version();
cached_weight_ = weight;
cached_sum_min_ = min_max_vec[kSumMin];
cached_sum_max_ = min_max_vec[kSumMax];
if (has_bias) {
cached_bias_min_ = min_max_vec[kBiasMin];
cached_bias_max_ = min_max_vec[kBiasMax];
bias_ver_ = in_data[idx.bias].version();
cached_bias_ = in_data[idx.bias];
} else {
cached_bias_ = NDArray();
}
const mxnet::TShape ishape = data.shape();
const auto data_ndim = ishape.ndim();
// Remember that data arrives in a DNNL layout so that later calls reorder it as well.
if (data.IsDNNLData()) {
reorder_data_ = true;
data = data.Reorder2Default();
}
// View data as 2D: either (prod(all but last), last) or, with flatten, (first, prod(rest)).
if (data_ndim != 2) {
if (!default_param.flatten) {
data =
data.DNNLDataReshape(Shape2(ishape.ProdShape(0, data_ndim - 1), ishape[data_ndim - 1]));
} else {
data = data.DNNLDataReshape(Shape2(ishape[0], ishape.ProdShape(1, data_ndim)));
}
}
// create cached out_md
dnnl::memory::desc out_md = CreateOutputMemoryDesc(output.shape(), output.dtype());
cached_out_mem_ = std::make_shared<dnnl::memory>(out_md, engine);
// Scales have to be known before the primitive is created from full_param_.
bool support_channelwise_scale = PrepareQuantization(ctx, in_data, output, min_max_vec);
fwd_.reset(new DNNLFullyConnectedForward(full_param_,
ctx.is_train,
data,
cached_weight_,
(has_bias ? &cached_bias_ : nullptr),
out_md));
GetCachedWeightsAndBias(weight, support_channelwise_scale, has_bias);
const auto data_mem = static_cast<const dnnl::memory*>(data.GetDNNLData());
cached_data_mem_ = std::make_shared<dnnl::memory>(data_mem->get_desc(), engine);
args_[DNNL_ARG_SRC] = *cached_data_mem_;
args_[DNNL_ARG_WEIGHTS] = *static_cast<const dnnl::memory*>(cached_weight_.GetDNNLData());
if (has_bias)
args_[DNNL_ARG_BIAS] = *static_cast<const dnnl::memory*>(cached_bias_.GetDNNLData());
args_[DNNL_ARG_DST] = *cached_out_mem_;
initialized_ = true;
}
if (reorder_data_) {
data = data.Reorder2Default();
}
// Point the cached source/destination memories at the buffers of this call.
cached_data_mem_->set_data_handle(reinterpret_cast<void*>(data.data().dptr_));
cached_out_mem_->set_data_handle(reinterpret_cast<void*>(output.data().dptr_));
DNNLStream::Get()->RegisterPrimArgs(fwd_->GetFwd(), args_);
DNNLStream::Get()->Submit();
// Without float output the cached output range is written to the min/max outputs.
if (!dnnl_param.enabled_float_output.has_value()) {
float* output_min_ptr = out_data[out_min_index].data().dptr<float>();
float* output_max_ptr = out_data[out_max_index].data().dptr<float>();
*output_min_ptr = cached_output_min_;
*output_max_ptr = cached_output_max_;
}
}
/*!
 * \brief Provides the array the primitive writes to when a sum input is fused.
 *
 * When the sum input already shares the output buffer and dtype it is returned as is.
 * Otherwise a memory wrapping the output buffer is created and a transfer of the sum input
 * into it is set up: a reorder to int32, a reorder to int8 scaled by u8_to_s8_scale, or a
 * plain DNNLMemoryCopy.
 *
 * \param sum_input array that is added to the FullyConnected result
 * \param output output array of the operator
 * \return array to be used as destination of the fused primitive
 */
NDArray SgDNNLFCOp::PrepareOutputWithSum(const NDArray& sum_input, const NDArray& output) {
  if (!initialized_) {
    // TODO(zhennan): Currently, dnnl fallback mechanism will break inplace option,
    // which makes the check (req[out_index] == kWriteInplace) useless; compare buffers instead.
    const auto* probe_src = static_cast<const dnnl::memory*>(sum_input.GetDNNLData());
    const auto* probe_dst = static_cast<const dnnl::memory*>(output.GetDNNLData());
    const bool shares_buffer = probe_src->get_data_handle() == probe_dst->get_data_handle();
    if (shares_buffer && sum_input.dtype() == output.dtype()) {
      inplace_ = true;
    }
  }

  if (inplace_) {
    return sum_input;
  }

  // Not in place: the content of sum_input has to end up in the output buffer.
  const auto* sum_mem = static_cast<const dnnl::memory*>(sum_input.GetDNNLData());
  const auto* dst_mem = static_cast<const dnnl::memory*>(output.GetDNNLData());

  if (output.dtype() == mshadow::kInt32) {
    // Same layout as the sum input, but with int32 elements.
    auto int32_desc           = sum_mem->get_desc();
    int32_desc.data.data_type = static_cast<dnnl_data_type_t>(get_dnnl_type(mshadow::kInt32));
    dnnl_mem_ptr wrapped_out(
        new dnnl::memory(int32_desc, CpuEngine::Get()->get_engine(), dst_mem->get_data_handle()));
    DNNLStream::Get()->RegisterMem(wrapped_out);
    DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(*sum_mem, *wrapped_out),
                                        {{DNNL_ARG_FROM, *sum_mem}, {DNNL_ARG_TO, *wrapped_out}});
    return NDArray(wrapped_out);
  }

  if (sum_input.dtype() == mshadow::kUint8 && output.dtype() == mshadow::kInt8) {
    // u8 -> s8: reorder with an output scale of u8_to_s8_scale.
    auto int8_desc           = sum_mem->get_desc();
    int8_desc.data.data_type = static_cast<dnnl_data_type_t>(get_dnnl_type(mshadow::kInt8));
    dnnl_mem_ptr wrapped_out(
        new dnnl::memory(int8_desc, CpuEngine::Get()->get_engine(), dst_mem->get_data_handle()));
    DNNLStream::Get()->RegisterMem(wrapped_out);

    const std::vector<float> rescale = {u8_to_s8_scale};
    dnnl::primitive_attr rescale_attr;
    rescale_attr.set_output_scales(0, rescale);
    const dnnl::reorder::primitive_desc rescale_pd(CpuEngine::Get()->get_engine(),
                                                   sum_mem->get_desc(),
                                                   CpuEngine::Get()->get_engine(),
                                                   int8_desc,
                                                   rescale_attr);
    DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(rescale_pd),
                                        {{DNNL_ARG_FROM, *sum_mem}, {DNNL_ARG_TO, *wrapped_out}});
    return NDArray(wrapped_out);
  }

  // Remaining combinations: copy without changing the descriptor.
  dnnl_mem_ptr wrapped_out(new dnnl::memory(
      sum_mem->get_desc(), CpuEngine::Get()->get_engine(), dst_mem->get_data_handle()));
  DNNLStream::Get()->RegisterMem(wrapped_out);
  DNNLMemoryCopy(*sum_mem, wrapped_out.get());
  return NDArray(wrapped_out);
}
/*!
 * \brief Tells whether the state cached by a previous Forward call can be reused.
 *
 * - Not initialized yet: returns false so that the caller builds the state.
 * - Initialized and MXNET_ONEDNN_QFC_DYNAMIC_PARAMS is not enabled: returns true. Change
 *   detection is opt-in, so the cached primitive, scales and weights are kept. Returning
 *   false here (as before) forced a complete re-initialization on every call.
 * - Initialized and the variable is enabled: returns false when the data/sum ranges changed,
 *   or, depending on the quantization mode, when the weight/bias versions (channel-wise) or
 *   the weight/bias ranges (per-tensor) changed.
 *
 * \param inputs operator inputs; weight and bias are read at fullc::kWeight / fullc::kBias
 * \param min_max_vec ranges of the current call, indexed by kDataMin..kSumMax
 * \param is_channel_wise whether channel-wise quantization is active
 * \return true if the cached state is valid for this call
 */
bool SgDNNLFCOp::CheckInitializationConditions(const std::vector<NDArray>& inputs,
                                               const std::vector<float>& min_max_vec,
                                               bool is_channel_wise) {
  if (!initialized_) {
    return false;
  }
  if (!dmlc::GetEnv("MXNET_ONEDNN_QFC_DYNAMIC_PARAMS", 0)) {
    // Dynamic parameter tracking is off: keep whatever has been built.
    return true;
  }

  const bool has_bias = !full_param_.default_param.no_bias;

  // Data and sum ranges matter in both quantization modes.
  if (cached_data_min_ != min_max_vec[kDataMin] || cached_data_max_ != min_max_vec[kDataMax] ||
      cached_sum_min_ != min_max_vec[kSumMin] || cached_sum_max_ != min_max_vec[kSumMax]) {
    return false;
  }

  if (is_channel_wise) {
    // No per-tensor ranges are available; rely on the NDArray versions instead.
    if (weight_ver_ != inputs[fullc::kWeight].version() ||
        (has_bias && (bias_ver_ != inputs[fullc::kBias].version()))) {
      return false;
    }
  } else {
    if (cached_weight_min_ != min_max_vec[kWeightMin] ||
        cached_weight_max_ != min_max_vec[kWeightMax] ||
        (has_bias && (cached_bias_min_ != min_max_vec[kBiasMin] ||
                      cached_bias_max_ != min_max_vec[kBiasMax]))) {
      return false;
    }
  }
  return true;
}
/*!
 * \brief Builds a 2D memory descriptor for the output.
 *
 * A 2D shape is used directly. Other shapes are collapsed to two dimensions: with `flatten`
 * the first axis is kept and the rest multiplied, otherwise the last axis is kept and the
 * leading axes are multiplied.
 *
 * \param oshape shape of the output array
 * \param out_dtype mshadow type flag of the output
 * \return descriptor using the default 2D format
 */
dnnl::memory::desc SgDNNLFCOp::CreateOutputMemoryDesc(const mxnet::TShape& oshape, int out_dtype) {
  const auto rank = oshape.ndim();

  dnnl::memory::dims dims_2d(2);
  if (rank == 2) {
    dims_2d[0] = static_cast<index_t>(oshape[0]);
    dims_2d[1] = static_cast<index_t>(oshape[1]);
  } else if (full_param_.default_param.flatten) {
    dims_2d[0] = static_cast<index_t>(oshape[0]);
    dims_2d[1] = static_cast<index_t>(oshape.ProdShape(1, rank));
  } else {
    dims_2d[0] = static_cast<index_t>(oshape.ProdShape(0, rank - 1));
    dims_2d[1] = static_cast<index_t>(oshape[rank - 1]);
  }

  const auto default_2d_format = static_cast<dnnl::memory::format_tag>(GetDefaultFormat(2));
  return dnnl::memory::desc(dims_2d, get_dnnl_type(out_dtype), default_2d_format);
}
/*!
 * \brief Computes the quantization scales and stores them in the operator state.
 *
 * Sets data_scale_ and weight_scales_, fills output_scales / sum_scale (and the eltwise
 * scale when an eltwise op is fused with requantize) in full_param_, and determines
 * cached_output_min_/max_ either from the calibration range or from the data/weight ranges.
 * In the per-tensor path with an int8 bias, cached_bias_ is replaced by an int32 array and
 * the int8 weight input may be rescaled in place to keep the bias within range.
 *
 * \param ctx operator context; its cpu stream is used to launch the range kernels
 * \param in_data operator inputs
 * \param output output array; only its dtype is read here
 * \param min_max_vec ranges of the current call, indexed by kDataMin..kSumMax
 * \return true when channel-wise scales are used
 */
bool SgDNNLFCOp::PrepareQuantization(const OpContext& ctx,
const std::vector<NDArray>& in_data,
const NDArray& output,
const std::vector<float>& min_max_vec) {
const auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
const FCInputIndex idx(full_param_);
bool support_channelwise_scale = false;
auto dnnl_param = full_param_.dnnl_param;
bool has_bias = !full_param_.default_param.no_bias;
const NDArray& data = in_data[fullc::kData];
const NDArray& weight = in_data[fullc::kWeight];
const bool channel_wise =
dnnl_param.channel_wise_quantize.has_value() && dnnl_param.channel_wise_quantize.value();
CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8);
data_scale_ = GetQuantizeScale(data.dtype(), cached_data_min_, cached_data_max_);
bool fuse_requantize = false;
// Channelwise scaling is only supported when fusion is enabled (requantize or dequantize).
// A calibration range on the node means requantize is fused and fixes the output range.
if (dnnl_param.min_calib_range.has_value() && dnnl_param.max_calib_range.has_value()) {
cached_output_min_ = dnnl_param.min_calib_range.value();
cached_output_max_ = dnnl_param.max_calib_range.value();
support_channelwise_scale = true;
fuse_requantize = true;
}
if (dnnl_param.enabled_float_output.has_value()) {
support_channelwise_scale = true;
}
// channel_wise support_channelwise_scale result
// True True True
// True False Error
// False True/False False
if (channel_wise && !support_channelwise_scale) {
LOG(FATAL) << "Currently, channel-wise quantization requires fuse requantize or dequantize."
<< " Please make sure the `min_calib_range` and `max_calib_range` are set when only"
<< " fuse requantize (outputs of FullyConnected are collected during calibration "
"phase),"
<< " or the env var of `MXNET_DISABLE_ONEDNN_QFC_FLOAT_OUTPUT` and "
<< " `MXNET_DISABLE_ONEDNN_QFC_FUSE_ALL` are not set to true (default is false)";
}
support_channelwise_scale = support_channelwise_scale && channel_wise;
if (support_channelwise_scale) {
// Channel-wise: the scales are derived from the cached weight (and bias) arrays.
MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, {
weight_scales_ = GetWeightScales<DType>(cached_weight_,
has_bias ? &cached_bias_ : nullptr,
data_scale_,
support_channelwise_scale);
});
} else {
// Per-tensor: a single weight scale computed from the cached weight range.
weight_scales_.resize(1);
weight_scales_[0] =
GetQuantizeScale(cached_weight_.dtype(), cached_weight_min_, cached_weight_max_);
if (has_bias) {
if (cached_bias_.dtype() == mshadow::kInt8) {
// An int8 bias is converted to int32 in the data_scale * weight_scale domain.
float bias_scale = GetQuantizeScale(mshadow::kInt8, cached_bias_min_, cached_bias_max_);
float bias_int32_rescale = data_scale_ * weight_scales_[0] / bias_scale;
// TODO(zhennan): dnnl has bug to handle INT_MAX in bias, so set
// the maximum value of bias to INT_MAX / 2.
float bias_max_rescale =
MaxValue<int32_t>() / 2 / MaxAbs(cached_bias_min_, cached_bias_max_) / bias_scale;
if (bias_int32_rescale > bias_max_rescale) {
// avoid overflow on bias
// The weight input is rescaled in place so that the clamped bias rescale stays consistent.
bias_int32_rescale = bias_max_rescale;
float weight_rescale = bias_int32_rescale * bias_scale / data_scale_ / weight_scales_[0];
int8_t* weight_ptr = weight.data().dptr<int8_t>();
size_t weight_size = weight.shape().Size();
#pragma omp parallel for num_threads(nthreads)
for (index_t i = 0; i < static_cast<index_t>(weight_size); ++i) {
weight_ptr[i] = std::round(weight_ptr[i] * weight_rescale);
}
weight_scales_[0] *= weight_rescale;
}
NDArray bias = in_data[fullc::kBias];
cached_bias_ =
NDArray(bias.storage_type(), bias.shape(), bias.ctx(), true, mshadow::kInt32);
int8_t* bias_ptr = bias.data().dptr<int8_t>();
int32_t* quantized_bias_ptr = cached_bias_.data().dptr<int32_t>();
size_t bias_size = bias.shape().Size();
#pragma omp parallel for num_threads(nthreads)
for (index_t i = 0; i < static_cast<index_t>(bias_size); ++i) {
quantized_bias_ptr[i] = std::round(bias_ptr[i] * bias_int32_rescale);
}
}
}
}
size_t num_channel = cached_weight_.shape()[0];
float out_scale = 1.0f;
if (fuse_requantize || dnnl_param.enabled_float_output.has_value()) {
// Fused requantize/dequantize: output scales are handed to the primitive via full_param_.
float tmp_scale_ = 1.0f;
if (fuse_requantize) {
if (dnnl_param.with_eltwise) {
// With a fused eltwise op the requantize scale is stored in the eltwise parameters.
tmp_scale_ = 1.0 / data_scale_;
full_param_.eltwise_param.scale =
GetQuantizeScale(output.dtype(), cached_output_min_, cached_output_max_);
} else {
out_scale = GetQuantizeScale(output.dtype(), cached_output_min_, cached_output_max_);
tmp_scale_ = out_scale / data_scale_;
}
} else {
tmp_scale_ = 1.0 / data_scale_;
}
if (support_channelwise_scale) {
full_param_.output_scales.resize(num_channel);
#pragma omp parallel for num_threads(nthreads)
for (index_t i = 0; i < static_cast<index_t>(num_channel); ++i) {
full_param_.output_scales[i] = tmp_scale_ / weight_scales_[i];
}
} else {
full_param_.output_scales.resize(1);
full_param_.output_scales[0] = tmp_scale_ / weight_scales_[0];
}
} else {
// No fusion: the output range is computed from the data and weight ranges.
Stream<cpu>* s = ctx.get_stream<cpu>();
if (data.dtype() == mshadow::kInt8) {
mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, cpu>::Launch(
s,
1,
&cached_output_min_,
&cached_output_max_,
&(min_max_vec[kDataMin]),
&(min_max_vec[kDataMax]),
&(min_max_vec[kWeightMin]),
&(min_max_vec[kWeightMax]));
} else {
mxnet_op::Kernel<QuantizationRangeForS8U8MultiplicationStruct, cpu>::Launch(
s,
1,
&cached_output_min_,
&cached_output_max_,
&(min_max_vec[kDataMin]),
&(min_max_vec[kDataMax]),
&(min_max_vec[kWeightMin]),
&(min_max_vec[kWeightMax]));
}
full_param_.output_scales.resize(0);
out_scale = data_scale_ * weight_scales_[0];
}
// A quantized sum input has to be brought to the scale of the output.
if (dnnl_param.with_sum && !dnnl_param.enabled_float_output.has_value()) {
float sum_in_scale =
GetQuantizeScale(in_data[idx.sum].dtype(), cached_sum_min_, cached_sum_max_);
full_param_.sum_scale = out_scale / sum_in_scale;
if (in_data[idx.sum].dtype() == mshadow::kUint8 && output.dtype() == mshadow::kInt8) {
// In this case, reorder with scale 0.5 is used on in_data[idx.sum] to
// scale it to s8 range, so sum_scale has to be rescaled as well
full_param_.sum_scale *= 1.0 / u8_to_s8_scale;
}
}
return support_channelwise_scale;
}
/*!
 * \brief Brings cached_weight_ / cached_bias_ into the form required by the primitive.
 *
 * If caching is enabled and the node carries an "__identifier__" attribute, previously
 * converted arrays are looked up in a thread-local map; otherwise (or on a miss) the arrays
 * are converted and, when caching applies, stored in that map.
 *
 * \param weight weight input of the operator; used as reorder source in the per-tensor
 *        quantized path
 * \param support_channelwise_scale whether channel-wise scales are used
 * \param has_bias whether the operator has a bias input
 */
void SgDNNLFCOp::GetCachedWeightsAndBias(const NDArray& weight,
bool support_channelwise_scale,
bool has_bias) {
// Thread-local map from signature to the converted (weight, bias) pair.
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<DNNLFullyconSignature, std::pair<NDArray, NDArray>, OpHash>
fcWeightsAndBias;
#else
static MX_THREAD_LOCAL
std::unordered_map<DNNLFullyconSignature, std::pair<NDArray, NDArray>, OpHash>
fcWeightsAndBias;
#endif
// Read once: the variable is stored in a function-local static.
static const bool use_cache = !(dmlc::GetEnv("MXNET_ONEDNN_DISABLE_FC_CACHE", 0));
const bool has_id = attrs.dict.count("__identifier__");
bool read_from_cache = false;
DNNLFullyconSignature key(full_param_);
if (use_cache && has_id) {
// The key combines the parameters, the weights descriptor, the identifier and the node name.
key.AddSign(fwd_->fwd_pd.weights_desc());
key.AddSign(attrs.dict["__identifier__"]);
key.AddSign(attrs.name);
auto it = fcWeightsAndBias.find(key);
if (it != fcWeightsAndBias.end()) {
cached_weight_ = it->second.first;
cached_bias_ = it->second.second;
read_from_cache = true;
common::LogOnce(
"oneDNN optimized version of FullyConnected for inference is being used. Weights and "
"bias are cached and can not be dynamically changed during runtime. To disable caching "
"mechanism use MXNET_ONEDNN_DISABLE_FC_CACHE=1.");
}
}
if (!read_from_cache) {
// convert weight and bias to the format that oneDNN requires
if (!full_param_.dnnl_param.quantized || support_channelwise_scale) {
dnnl::memory::desc bias_md;
if (has_bias)
bias_md = fwd_->fwd_pd.bias_desc();
ConvertWeightBias2DNNL(&cached_weight_,
&cached_bias_,
has_bias,
fwd_->fwd_pd.weights_desc(),
has_bias ? &bias_md : nullptr,
1,
data_scale_,
weight_scales_,
false);
} else {
// Per-tensor quantized weights only need a reorder when the descriptors differ.
const auto def_weight_mem = weight.GetDNNLData();
if (def_weight_mem->get_desc() != fwd_->fwd_pd.weights_desc()) {
auto weight_desc = fwd_->fwd_pd.weights_desc();
cached_weight_ = NDArray(&weight_desc);
auto cached_weight_mem = cached_weight_.GetDNNLData();
std::unordered_map<int, dnnl::memory> args(
{{DNNL_ARG_FROM, *def_weight_mem}, {DNNL_ARG_TO, *cached_weight_mem}});
DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(*def_weight_mem, *cached_weight_mem),
args);
}
}
if (use_cache && has_id)
AddToCache(&fcWeightsAndBias, key, {cached_weight_, cached_bias_});
}
}
/*!
 * \brief Parses the node attributes into a DNNLFCFullParam stored in `attrs->parsed`.
 *
 * The legacy attribute `with_relu` is renamed to `with_eltwise`, the oneDNN specific
 * parameters are initialized from the attribute dictionary, and the first subgraph is visited
 * to pick up the FullyConnected parameters and the fused eltwise configuration.
 *
 * \param attrs node attributes; `dict` and `parsed` are modified
 * \throws dmlc::ParamError with the operator name and all attributes appended to the message
 */
static void SgDNNLFCParamParser(nnvm::NodeAttrs* attrs) {
  // For backward compatible, with_relu->with_eltwise
  auto legacy = attrs->dict.find("with_relu");
  if (legacy != attrs->dict.end()) {
    // Take the value and erase first: inserting "with_eltwise" may rehash the map, which
    // would invalidate `legacy` before it is erased.
    std::string legacy_value = std::move(legacy->second);
    attrs->dict.erase(legacy);
    attrs->dict["with_eltwise"] = std::move(legacy_value);
  }

  DNNLFCFullParam full_param;
  try {
    full_param.dnnl_param.Init(attrs->dict);
  } catch (const dmlc::ParamError& e) {
    // Re-throw with the operator and its attributes appended for easier diagnosis.
    std::ostringstream os;
    os << e.what();
    os << ", in operator " << attrs->op->name << "("
       << "name=\"" << attrs->name << "\"";
    for (const auto& k : attrs->dict) {
      os << ", " << k.first << "=\"" << k.second << "\"";
    }
    os << ")";
    throw dmlc::ParamError(os.str());
  }

  auto subgraph_sym = attrs->subgraphs[0];
  DFSVisit(subgraph_sym->outputs, [&](const nnvm::ObjectPtr& node) {
    if (node->is_variable())
      return;
    auto& op_name = node->op()->name;
    if (op_name == "FullyConnected") {
      full_param.default_param = nnvm::get<FullyConnectedParam>(node->attrs.parsed);
    } else if (SupportDNNLFCEltwiseFusion(op_name)) {
      if (op_name == "Activation") {
        const ActivationParam act_param = nnvm::get<ActivationParam>(node->attrs.parsed);
        full_param.eltwise_param.alg    = GetDNNLActAlgo(act_param);
      } else if (op_name == "LeakyReLU") {
        const auto act_param           = nnvm::get<LeakyReLUParam>(node->attrs.parsed);
        full_param.eltwise_param.alpha = act_param.slope;
        full_param.eltwise_param.alg   = GetDNNLActAlgo(act_param);
      } else if (op_name == "clip") {
        // clip is mapped to bounded relu with the upper bound as alpha.
        const ClipParam clip_param     = nnvm::get<ClipParam>(node->attrs.parsed);
        full_param.eltwise_param.alg   = dnnl::algorithm::eltwise_bounded_relu;
        full_param.eltwise_param.alpha = clip_param.a_max;
      } else {
        full_param.eltwise_param.alg = GetDNNLEltwiseAlgo(op_name);
      }
    }
  });
  attrs->parsed = std::move(full_param);
}
/*!
 * \brief Lists the input names: the subgraph inputs followed by the quantization ranges.
 *
 * For quantized nodes `data_min`/`data_max` are always appended. `weight_min`/`weight_max`
 * (and `bias_min`/`bias_max` when a bias exists) are appended only for per-tensor
 * quantization, and `sum_min`/`sum_max` only for a fused sum without float output.
 *
 * \param attrs node attributes holding a parsed DNNLFCFullParam
 * \return input names in operator input order
 */
static std::vector<std::string> SgDNNLFCListInputNames(const NodeAttrs& attrs) {
  auto const& full_param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
  auto const& dnnl_param = full_param.dnnl_param;
  std::vector<std::string> input_names = DefaultSubgraphOpListInputs(attrs);
  if (dnnl_param.quantized) {
    // Read the stored value: converting the optional itself to bool only tells whether a
    // value is present, which would treat an explicit `false` as channel-wise. This matches
    // the check used by SgDNNLFCOp::Forward.
    const bool channel_wise =
        dnnl_param.channel_wise_quantize.has_value() && dnnl_param.channel_wise_quantize.value();
    input_names.emplace_back("data_min");
    input_names.emplace_back("data_max");
    if (!channel_wise) {
      input_names.emplace_back("weight_min");
      input_names.emplace_back("weight_max");
      if (!full_param.default_param.no_bias) {
        input_names.emplace_back("bias_min");
        input_names.emplace_back("bias_max");
      }
    }
    if (dnnl_param.with_sum && !dnnl_param.enabled_float_output.has_value()) {
      input_names.emplace_back("sum_min");
      input_names.emplace_back("sum_max");
    }
  }
  return input_names;
}
/*!
 * \brief Lists the output names.
 *
 * Quantized nodes without float output expose "output", "output_min" and "output_max";
 * every other configuration exposes only "output".
 *
 * \param attrs node attributes holding a parsed DNNLFCFullParam
 * \return output names in operator output order
 */
static std::vector<std::string> SgDNNLFCListOutputNames(const NodeAttrs& attrs) {
  const auto& dnnl_param = nnvm::get<DNNLFCFullParam>(attrs.parsed).dnnl_param;
  const bool reports_range =
      dnnl_param.quantized && !dnnl_param.enabled_float_output.has_value();

  std::vector<std::string> names;
  names.emplace_back("output");
  if (reports_range) {
    names.emplace_back("output_min");
    names.emplace_back("output_max");
  }
  return names;
}
/*!
 * \brief Extracts the attributes that belong to the non-quantized (base) operator.
 *
 * Appends the first output attribute to `base_out_attrs` and the first
 * `FCInputIndex(param).GetBase()` input attributes to `base_in_attrs`.
 *
 * \param param full parameter set used to determine the number of base inputs
 * \param base_in_attrs receives the base input attributes
 * \param base_out_attrs receives the base output attribute
 * \param in_attrs attributes of all operator inputs
 * \param out_attrs attributes of all operator outputs
 */
template <typename T>
static inline void FillBaseInputOutputInfo(const DNNLFCFullParam& param,
                                           std::vector<T>* base_in_attrs,
                                           std::vector<T>* base_out_attrs,
                                           std::vector<T>* in_attrs,
                                           std::vector<T>* out_attrs) {
  const auto num_base = FCInputIndex(param).GetBase();
  base_out_attrs->push_back(out_attrs->at(0));
  int pos = 0;
  while (pos < num_base) {
    base_in_attrs->push_back(in_attrs->at(pos));
    ++pos;
  }
}
/*!
 * \brief Shape inference for the fused operator.
 *
 * Non-quantized nodes use the default subgraph inference. For quantized nodes the base
 * inputs/output are inferred through the subgraph, and every additional input as well as the
 * min/max outputs (when present) are checked against shape (1,).
 *
 * \param attrs node attributes holding a parsed DNNLFCFullParam
 * \param in_shapes input shapes, updated in place
 * \param out_shapes output shapes, updated in place
 * \return result of the default subgraph shape inference
 */
static bool SgDNNLFCInferShape(const nnvm::NodeAttrs& attrs,
                               mxnet::ShapeVector* in_shapes,
                               mxnet::ShapeVector* out_shapes) {
  auto const& full_param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
  if (!full_param.dnnl_param.quantized) {
    return DefaultSubgraphOpShape(attrs, in_shapes, out_shapes);
  }

  mxnet::ShapeVector subgraph_in;
  mxnet::ShapeVector subgraph_out;
  FillBaseInputOutputInfo(full_param, &subgraph_in, &subgraph_out, in_shapes, out_shapes);
  const bool inferred = DefaultSubgraphOpShape(attrs, &subgraph_in, &subgraph_out);

  // Leading inputs take the inferred shapes, the remaining ones are scalar ranges.
  size_t pos = 0;
  for (; pos < in_shapes->size() && pos < subgraph_in.size(); ++pos) {
    in_shapes->at(pos) = subgraph_in[pos];
  }
  for (; pos < in_shapes->size(); ++pos) {
    SHAPE_ASSIGN_CHECK(*in_shapes, pos, Shape1(1));
  }

  out_shapes->at(0) = subgraph_out[0];
  if (!full_param.dnnl_param.enabled_float_output.has_value()) {
    SHAPE_ASSIGN_CHECK(*out_shapes, 1, Shape1(1));
    SHAPE_ASSIGN_CHECK(*out_shapes, 2, Shape1(1));
  }
  return inferred;
}
/*!
 * \brief Type inference for the fused operator.
 *
 * Non-quantized nodes use the default subgraph inference; the output type is then overridden
 * by `enabled_float_output` when it is set. For quantized nodes:
 *  - bfloat16 data makes the inference fail (returns false); otherwise data must be int8/uint8,
 *  - weight/bias are float32 for channel-wise quantization, else weight is int8 and bias is
 *    int8 or int32 (int32 when still unknown),
 *  - a sum input takes the float output type, or must be int8/uint8,
 *  - all inputs from `data_min` on are float32,
 *  - the output is the float output type, uint8/int8 when a calibration range is set, or
 *    int32 otherwise; min/max outputs are float32.
 *
 * \param attrs node attributes holding a parsed DNNLFCFullParam
 * \param in_types input type flags, updated in place
 * \param out_types output type flags, updated in place
 * \return whether the types could be inferred
 */
static bool SgDNNLFCInferType(const nnvm::NodeAttrs& attrs,
                              std::vector<int>* in_types,
                              std::vector<int>* out_types) {
  auto const& full_param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
  if (full_param.dnnl_param.quantized) {
    // Read the stored value: converting the optional itself to bool only tells whether a
    // value is present, which would treat an explicit `false` as channel-wise. This matches
    // the check used by SgDNNLFCOp::Forward.
    const bool channel_wise = full_param.dnnl_param.channel_wise_quantize.has_value() &&
                              full_param.dnnl_param.channel_wise_quantize.value();
    const FCInputIndex idx(full_param);
    if (in_types->at(idx.data) == mshadow::kBfloat16) {
      return false;
    }
    CHECK(in_types->at(idx.data) == mshadow::kInt8 || in_types->at(idx.data) == mshadow::kUint8)
        << "QuantizedFullyConnected data input only supports int8/uint8, while "
        << in_types->at(idx.data) << " is given.";
    if (channel_wise) {
      TYPE_ASSIGN_CHECK(*in_types, idx.weight, mshadow::kFloat32);
      if (idx.IsBiasExist()) {
        TYPE_ASSIGN_CHECK(*in_types, idx.bias, mshadow::kFloat32);
      }
    } else {
      TYPE_ASSIGN_CHECK(*in_types, idx.weight, mshadow::kInt8);
      if (idx.IsBiasExist()) {
        if (in_types->at(idx.bias) == -1) {
          // Unknown bias type defaults to int32.
          TYPE_ASSIGN_CHECK(*in_types, idx.bias, mshadow::kInt32);
        } else {
          CHECK(in_types->at(idx.bias) == mshadow::kInt8 ||
                in_types->at(idx.bias) == mshadow::kInt32)
              << "QuantizedFullyConnected bias input only supports int8/int32, while "
              << in_types->at(idx.bias) << " is given.";
        }
      }
    }
    if (idx.IsSumExist()) {
      if (full_param.dnnl_param.enabled_float_output.has_value()) {
        TYPE_ASSIGN_CHECK(*in_types, idx.sum, full_param.dnnl_param.enabled_float_output.value());
      } else {
        CHECK(in_types->at(idx.sum) == mshadow::kInt8 || in_types->at(idx.sum) == mshadow::kUint8)
            << "QuantizedFullyConnected sum input only supports int8/uint8, while "
            << in_types->at(idx.sum) << " is given.";
      }
    }
    // Everything from data_min on is a min/max range input.
    for (size_t i = idx.data_min; i < in_types->size(); ++i) {
      TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
    }
    if (full_param.dnnl_param.enabled_float_output.has_value()) {
      TYPE_ASSIGN_CHECK(*out_types, 0, full_param.dnnl_param.enabled_float_output.value());
    } else {
      if (full_param.dnnl_param.min_calib_range.has_value() &&
          full_param.dnnl_param.max_calib_range.has_value()) {
        // uint8 output requires that a fused sum input, if any, is uint8 as well.
        if (IsOutputUint8(full_param) &&
            (!idx.IsSumExist() || in_types->at(idx.sum) == mshadow::kUint8)) {
          TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8);
        } else {
          TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8);
        }
      } else {
        TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32);
      }
      TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32);
      TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32);
    }
    return true;
  } else {
    bool result = DefaultSubgraphOpType(attrs, in_types, out_types);
    if (full_param.dnnl_param.enabled_float_output.has_value()) {
      (*out_types)[0] = full_param.dnnl_param.enabled_float_output.value();
    }
    return result;
  }
}
/*!
 * \brief Storage type inference for the fused operator.
 *
 * Non-quantized nodes use the default subgraph inference. For quantized nodes the base
 * inputs/output are inferred through the subgraph, while the additional inputs and the
 * min/max outputs (when present) are assigned default storage.
 *
 * \param attrs node attributes holding a parsed DNNLFCFullParam
 * \param dev_mask device mask forwarded to the default inference
 * \param dispatch_mode dispatch mode forwarded to the default inference
 * \param in_attrs input storage types, updated in place
 * \param out_attrs output storage types, updated in place
 * \return result of the default subgraph storage type inference
 */
static bool SgDNNLFCStorageType(const nnvm::NodeAttrs& attrs,
                                const int dev_mask,
                                DispatchMode* dispatch_mode,
                                std::vector<int>* in_attrs,
                                std::vector<int>* out_attrs) {
  auto const& full_param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
  if (!full_param.dnnl_param.quantized) {
    return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
  }

  std::vector<int> subgraph_in;
  std::vector<int> subgraph_out;
  FillBaseInputOutputInfo(full_param, &subgraph_in, &subgraph_out, in_attrs, out_attrs);
  const bool inferred =
      DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, &subgraph_in, &subgraph_out);

  // Leading inputs take the inferred storage types, the range inputs use default storage.
  size_t pos = 0;
  for (; pos < in_attrs->size() && pos < subgraph_in.size(); ++pos) {
    in_attrs->at(pos) = subgraph_in[pos];
  }
  for (; pos < in_attrs->size(); ++pos) {
    type_assign(&in_attrs->at(pos), mxnet::kDefaultStorage);
  }

  out_attrs->at(0) = subgraph_out[0];
  if (!full_param.dnnl_param.enabled_float_output.has_value()) {
    type_assign(&out_attrs->at(1), mxnet::kDefaultStorage);
    type_assign(&out_attrs->at(2), mxnet::kDefaultStorage);
  }
  return inferred;
}
/*!
 * \brief Reports which input may share memory with which output.
 *
 * \param attrs node attributes holding a parsed DNNLFCFullParam
 * \return the pair (sum input index, 0) when a sum is fused, otherwise an empty list
 */
std::vector<std::pair<int, int>> SgDNNLFCInplaceOption(const NodeAttrs& attrs) {
  auto const& param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
  std::vector<std::pair<int, int>> inplace_pairs;
  if (param.dnnl_param.with_sum) {
    inplace_pairs.emplace_back(FCInputIndex(param).sum, 0);
  }
  return inplace_pairs;
}
/*!
 * \brief Creates the operator state, a SgDNNLFCOp constructed from the node attributes.
 *
 * \param attrs node attributes passed to the SgDNNLFCOp constructor
 * \param ctx unused
 * \param in_shapes unused
 * \param in_types unused
 * \return state pointer owning the new SgDNNLFCOp
 */
static OpStatePtr CreateSgDNNLFCState(const nnvm::NodeAttrs& attrs,
                                      Context ctx,
                                      const mxnet::ShapeVector& in_shapes,
                                      const std::vector<int>& in_types) {
  OpStatePtr fc_state = OpStatePtr::Create<SgDNNLFCOp>(attrs);
  return fc_state;
}
/*!
 * \brief Stateful compute entry point: forwards the call to the SgDNNLFCOp held by the state.
 *
 * \param state_pointer operator state holding a SgDNNLFCOp
 * \param ctx operator context
 * \param inputs operator inputs
 * \param req write requests
 * \param outputs operator outputs
 */
static void SgDNNLFCForward(const OpStatePtr& state_pointer,
                            const OpContext& ctx,
                            const std::vector<NDArray>& inputs,
                            const std::vector<OpReqType>& req,
                            const std::vector<NDArray>& outputs) {
  state_pointer.get_state<SgDNNLFCOp>().Forward(ctx, inputs, req, outputs);
}
/*!
 * \brief Creates the quantized counterpart of a fused FullyConnected node.
 *
 * The new node uses the same operator, is named "quantized_" + original name, copies the
 * attribute dictionary and subgraphs, sets `quantized` to "True" and re-runs the attribute
 * parser.
 *
 * \param attrs attributes of the node to be quantized
 * \return the newly created node
 */
nnvm::ObjectPtr SgDNNLFCQuantizedOp(const NodeAttrs& attrs) {
  nnvm::ObjectPtr quantized_node = nnvm::Node::Create();
  auto& q_attrs                  = quantized_node->attrs;

  q_attrs.op                = Op::Get("_sg_onednn_fully_connected");
  q_attrs.name              = "quantized_" + attrs.name;
  q_attrs.dict              = attrs.dict;
  q_attrs.dict["quantized"] = "True";

  q_attrs.subgraphs.reserve(attrs.subgraphs.size());
  for (const auto& subgraph : attrs.subgraphs) {
    q_attrs.subgraphs.push_back(subgraph);
  }

  quantized_node->op()->attr_parser(&q_attrs);
  return quantized_node;
}
/*!
 * \brief Tells the quantization pass whether an input has to stay unquantized.
 *
 * With "channel-wise" granularity the weight and (if present) the bias are kept as they are.
 * Independently of the granularity, the sum input is kept when
 * FCInputIndex::IsSumInputFloat() reports true.
 *
 * \param attrs node attributes holding a parsed DNNLFCFullParam
 * \param index_to_check index of the input in question
 * \param quantize_granularity granularity requested by the quantization pass
 * \return true if the input must not be quantized
 */
static bool SgDNNLAvoidFCQuantizeInput(const NodeAttrs& attrs,
                                       const size_t index_to_check,
                                       const std::string quantize_granularity) {
  auto const& full_param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
  FCInputIndex idx(full_param);

  bool keep_unquantized = false;
  if (quantize_granularity == "channel-wise") {
    const bool has_bias = !full_param.default_param.no_bias;
    keep_unquantized    = index_to_check == static_cast<size_t>(fullc::kWeight) ||
                       (has_bias && index_to_check == static_cast<size_t>(fullc::kBias));
  }
  if (idx.IsSumInputFloat()) {
    keep_unquantized = keep_unquantized || index_to_check == static_cast<size_t>(idx.sum);
  }
  return keep_unquantized;
}
// Registration of the fused FullyConnected subgraph operator. The legacy name
// `_sg_mkldnn_fully_connected` is kept as an alias.
NNVM_REGISTER_OP(_sg_onednn_fully_connected)
.add_alias("_sg_mkldnn_fully_connected")
.describe(R"code(_sg_onednn_fully_connected)code" ADD_FILELINE)
// The number of inputs depends on the parsed parameters (see FCInputIndex).
.set_num_inputs([](const NodeAttrs& attrs) {
auto const& full_param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
return FCInputIndex(full_param).GetTotal();
})
// Quantized nodes without float output have three outputs (output, min, max), others one.
.set_num_outputs([](const NodeAttrs& attrs) {
auto const& full_param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
if (full_param.dnnl_param.quantized &&
!full_param.dnnl_param.enabled_float_output.has_value()) {
return 3;
}
return 1;
})
.set_attr_parser(SgDNNLFCParamParser)
.set_attr<nnvm::FListInputNames>("FListInputNames", SgDNNLFCListInputNames)
.set_attr<nnvm::FListOutputNames>("FListOutputNames", SgDNNLFCListOutputNames)
.set_attr<mxnet::FInferShape>("FInferShape", SgDNNLFCInferShape)
.set_attr<nnvm::FInferType>("FInferType", SgDNNLFCInferType)
.set_attr<FInferStorageType>("FInferStorageType", SgDNNLFCStorageType)
.set_attr<FCreateOpState>("FCreateOpState", CreateSgDNNLFCState)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", SgDNNLFCForward)
.set_attr<bool>("TIsDNNL", true)
// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
// will be reverted after the improvement of CachedOP is done.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::FMutateInputs>("FMutateInputs", DefaultSubgraphOpMutableInputs)
.set_attr<std::string>("key_var_num_args", "num_args")
.set_attr<nnvm::FInplaceOption>("FInplaceOption", SgDNNLFCInplaceOption)
// Quantization pass hooks: the operator is always quantized and requests requantize.
.set_attr<FQuantizable>("FQuantizable",
[](const NodeAttrs& attrs) { return QuantizeType::kMust; })
.set_attr<FQuantizedOp>("FQuantizedOp", SgDNNLFCQuantizedOp)
.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
.set_attr<FAvoidQuantizeInput>("FAvoidQuantizeInput", SgDNNLAvoidFCQuantizeInput);
} // namespace op
} // namespace mxnet
#endif // if MXNET_USE_ONEDNN == 1