Restore quantized RNN to master (#20952)
* Restore quantized RNN
sanity
* Add tests & disable LSTMP from quantization
* apply review comments
* change link
* Add new lines at the EOF
* Add ops to amp lists
* Remove unused features
* Fix DataDesc handling in quantization
* fix website
* fix sanity
* remove magic number
Co-authored-by: Bartlomiej Gawrych <barlomiej.gawrych@intel.com>
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index c936d3e..0bc2a8f 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -345,6 +345,19 @@
/*!
* \brief Register a function to determine if the input of a quantized operator
+ * needs to be quantized asymmetrically.
+ */
+using FNeedAsymQuantizeInput = std::function<bool(const NodeAttrs& attrs, const size_t index)>;
+
+/*!
+ * \brief Register a function to determine if the output of a quantized operator
+ * needs to be dequantized. This is usually used for the quantized operators
+ * which can produce fp32 outputs directly.
+ */
+using FAvoidDequantizeOutput = std::function<bool(const NodeAttrs& attrs, const size_t index)>;
+
+/*!
+ * \brief Register a function to determine if the input of a quantized operator
* needs to be calibrated. This is usually used for the quantized operators
* which need calibration on its input.
*/
diff --git a/python/mxnet/amp/lists/symbol_bf16.py b/python/mxnet/amp/lists/symbol_bf16.py
index 27a5e3a..bab5268 100644
--- a/python/mxnet/amp/lists/symbol_bf16.py
+++ b/python/mxnet/amp/lists/symbol_bf16.py
@@ -97,6 +97,7 @@
'_contrib_index_copy',
'_contrib_quadratic',
'_contrib_quantize',
+ '_contrib_quantize_asym',
'_contrib_quantize_v2',
'_contrib_quantized_concat',
'_contrib_quantized_conv',
@@ -105,6 +106,7 @@
'_contrib_quantized_pooling',
'_contrib_quantized_elemwise_add',
'_contrib_quantized_act',
+ '_contrib_quantized_rnn',
'_image_crop',
'_linspace',
'_contrib_requantize',
diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py
index 52a8f45..ad1f0ad 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -99,6 +99,7 @@
'_contrib_index_copy',
'_contrib_quadratic',
'_contrib_quantize',
+ '_contrib_quantize_asym',
'_contrib_quantize_v2',
'_contrib_quantized_concat',
'_contrib_quantized_conv',
@@ -108,6 +109,7 @@
'_contrib_quantized_elemwise_add',
'_contrib_quantized_act',
'_contrib_quantized_reshape',
+ '_contrib_quantized_rnn',
'_contrib_quantized_transpose',
'_npx_quantized_reshape',
'_npx_quantized_transpose',
diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py
index 10d2455..6494235 100644
--- a/python/mxnet/contrib/quantization.py
+++ b/python/mxnet/contrib/quantization.py
@@ -33,6 +33,20 @@
from ..util import is_np_array, wrap_ctx_to_device_func
+def _multilist_iterator(arg, func):
+ """Iterate over multidiemnsional list and returns new list
+ with same dimensions, but applied `func` function on list elements.
+ E.g. _multilist_iterator([1, 2, [3, 4]], lambda x: x**2) = [1, 4, [9, 16]]
+ """
+ ret = []
+ if isinstance(arg, list):
+ for el in arg:
+ ret.append(_multilist_iterator(el, func))
+ else:
+ return func(arg)
+
+ return ret
+
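
A minimal sketch of the nesting-preserving traversal added above (the helper itself stays private to `mxnet.contrib.quantization`; shapes below are hypothetical). Preserving the nesting is what lets calibration handle RNN-style batches such as `[data, [state_h, state_c]]`:

```python
# Sketch mirroring the _multilist_iterator helper added above.
def multilist_iterator(arg, func):
    if isinstance(arg, list):
        return [multilist_iterator(el, func) for el in arg]
    return func(arg)

# A nested calibration batch (data plus a list of initial states) keeps its
# structure while func is applied to every leaf element.
shapes = [(32, 10), [(1, 32, 64), (1, 32, 64)]]              # hypothetical shapes
print(multilist_iterator(shapes, lambda shape: len(shape)))  # [2, [3, 3]]
```
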
def _quantize_params(qsym, params, min_max_dict):
"""Given a quantized symbol and a dict of params that have not been quantized,
generate quantized params. Currently only supports quantizing the arg_params
@@ -357,7 +371,7 @@
for batch in data:
if not isinstance(batch, list):
batch = [batch]
- batch = [b.as_in_context(mx.cpu()) for b in batch]
+ batch = _multilist_iterator(batch, lambda b: b.as_in_context(mx.cpu()))
sym_block(*batch[:num_inputs])
num_batches += 1
if num_calib_batches is not None and num_batches >= num_calib_batches:
@@ -368,19 +382,44 @@
def _generate_list_of_data_desc(data_shapes, data_types):
- """"Convert list ot tuples to list of DataDesc."""
- if isinstance(data_shapes, list):
- if all(isinstance(x, DataDesc) for x in data_shapes):
- return data_shapes
- if all(isinstance(x, tuple) for x in data_shapes):
- if len(data_shapes) == 1:
- data_shapes = [DataDesc(name='data', shape=data_shapes[0], dtype=data_types[0])]
+ """Convert list of tuples to list of DataDesc."""
+ def flatten_list(arg):
+ ret = []
+ for el in arg:
+ if isinstance(el, list):
+ ret += flatten_list(el)
else:
- data_shapes = [DataDesc(name='data' + str(i), shape=data_shapes[i],
- dtype=data_types[i]) for i in range(len(data_shapes))]
- return data_shapes
- raise ValueError('data_shapes must be either a list of DataDesc or a list of Tuple')
+ ret.append(el)
+ return ret
+ flattened_data_types = flatten_list(data_types)
+ flattened_data_shapes = flatten_list(data_shapes)
+
+ if all(isinstance(x, DataDesc) for x in flattened_data_shapes):
+ return data_shapes
+
+ assert len(flattened_data_types) == len(flattened_data_shapes)
+
+    # use a one-element list as a mutable counter shared with the nested function
+ counter = [0]
+ def get_data_desc(data_shape, counter=counter, data_types=flattened_data_types):
+ if isinstance(data_shape, DataDesc):
+ return data_shape
+ elif isinstance(data_shape, tuple):
+ desc = DataDesc(name='data' + str(counter[0]), shape=data_shape,
+ dtype=data_types[counter[0]])
+ counter[0] += 1
+ return desc
+ else:
+ raise ValueError('data_shapes must be either a list of DataDesc or a list of Tuple')
+
+
+ if len(data_shapes) == 1 and not isinstance(data_shapes[0], list):
+ data_descs = [DataDesc(name='data', shape=data_shapes[0], dtype=data_types[0])]
+ else:
+ data_descs = _multilist_iterator(data_shapes, get_data_desc)
+
+ return data_descs
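
A short sketch of the naming scheme implemented above: tuples become `DataDesc` entries named `data0`, `data1`, ... in traversal order, while the nesting is preserved (names and shapes below are illustrative only):

```python
# Sketch of the data0/data1/... naming used by _generate_list_of_data_desc.
def describe(shapes, counter):
    if isinstance(shapes, list):
        return [describe(s, counter) for s in shapes]
    name = 'data' + str(counter[0])   # counter is a one-element list, mutated in place
    counter[0] += 1
    return (name, shapes)

nested_shapes = [(32, 10), [(1, 32, 64), (1, 32, 64)]]  # hypothetical shapes
print(describe(nested_shapes, [0]))
# [('data0', (32, 10)), [('data1', (1, 32, 64)), ('data2', (1, 32, 64))]]
```
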
@wrap_ctx_to_device_func
def quantize_model(sym, arg_params, aux_params, data_names=('data',),
@@ -841,8 +880,8 @@
x = iter(calib_data)
batch = next(x)
if isinstance(batch, list):
- data_shapes = [b.shape for b in batch]
- data_types = [b.dtype for b in batch]
+ data_shapes = _multilist_iterator(batch, lambda x: x.shape)
+ data_types = _multilist_iterator(batch, lambda x: x.dtype)
else:
data_shapes = [batch.shape]
data_types = [batch.dtype]
@@ -850,16 +889,15 @@
raise ValueError('calib_data expects mx.gluon.data.DataLoader')
if data_types is None:
- data_types = [mx_real_t] * len(data_shapes)
+ data_types = _multilist_iterator(data_shapes, lambda x: mx_real_t)
+
data_descs = _generate_list_of_data_desc(data_shapes, data_types)
num_inputs = len(data_descs)
data_nd = []
- for desc in data_descs:
- if is_np_array():
- data_nd.append(mx.np.zeros(shape=desc.shape, dtype=desc.dtype))
- else:
- data_nd.append(mx.nd.zeros(shape=desc.shape, dtype=desc.dtype))
+ arr_fn = mx.np if is_np_array() else mx.nd
+ data_nd = _multilist_iterator(data_descs, lambda d, F=arr_fn: F.zeros(shape=d.shape, dtype=d.dtype))
+
while True:
try:
network(*data_nd)
@@ -919,7 +957,7 @@
raise ValueError(
'calib_data must be provided when calib_mode=%s' % calib_mode)
if calib_mode in ['naive', 'entropy', 'custom']:
- inputs = [mx.sym.var(desc.name) for desc in data_descs]
+ inputs = _multilist_iterator(data_descs, lambda dd: mx.sym.var(dd.name))
calib_net = SymbolBlock(symnet, inputs)
for k, v in calib_net.collect_params().items():
v.grad_req = 'null'
@@ -939,7 +977,7 @@
else:
raise ValueError('calib_mode has to be one of: naive, entropy, custom')
elif calib_mode is not None and calib_mode == 'none':
- inputs = [mx.sym.var(desc.name) for desc in data_descs]
+ inputs = _multilist_iterator(data_descs, lambda dd: mx.sym.var(dd.name))
net = SymbolBlock(qsym, inputs)
for k, v in net.collect_params().items():
diff --git a/python/mxnet/io/io.py b/python/mxnet/io/io.py
index 4d78cd9..b0a0129 100644
--- a/python/mxnet/io/io.py
+++ b/python/mxnet/io/io.py
@@ -643,8 +643,11 @@
@property
def provide_label(self):
"""The name and shape of label provided by this iterator."""
+ batch_axis = self.layout.find('N')
return [
- DataDesc(k, tuple([self.batch_size] + list(v.shape[1:])), v.dtype)
+ DataDesc(k, tuple(list(v.shape[:batch_axis]) + \
+ [self.batch_size] + list(v.shape[batch_axis + 1:])),
+ v.dtype, layout=self.layout)
for k, v in self.label
]
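
A small sketch of the layout-aware shape construction above: the batch size replaces the dimension at the batch axis taken from the iterator layout, instead of always being placed first (layouts and shapes below are illustrative):

```python
# Sketch: place batch_size on the batch axis given by the layout string.
def desc_shape(array_shape, batch_size, layout):
    batch_axis = layout.find('N')
    return tuple(list(array_shape[:batch_axis]) + [batch_size] +
                 list(array_shape[batch_axis + 1:]))

print(desc_shape((35, 4, 200), batch_size=32, layout='TNC'))       # (35, 32, 200)
print(desc_shape((4, 3, 224, 224), batch_size=32, layout='NCHW'))  # (32, 3, 224, 224)
```
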
diff --git a/src/operator/nn/dnnl/dnnl_rnn-inl.h b/src/operator/nn/dnnl/dnnl_rnn-inl.h
index f287534..6165dfa 100644
--- a/src/operator/nn/dnnl/dnnl_rnn-inl.h
+++ b/src/operator/nn/dnnl/dnnl_rnn-inl.h
@@ -32,10 +32,20 @@
#include "operator/rnn-inl.h"
#include "dnnl_base-inl.h"
+#include "operator/quantization/quantized_rnn-inl.h"
namespace mxnet {
namespace op {
+struct DNNLRnnParam : public dmlc::Parameter<DNNLRnnParam> {
+ bool quantized;
+
+ DMLC_DECLARE_PARAMETER(DNNLRnnParam) {
+ DMLC_DECLARE_FIELD(quantized).set_default(false).describe(
+ "Whether it's a quantized RNN operator");
+ }
+};
+
struct DNNLRnnLayerParam {
using memory = dnnl::memory;
using dims = dnnl::memory::dims;
@@ -66,6 +76,10 @@
size_t native_single_b_size; // bias size of a single cell from framework
size_t single_state_size; // state size of a single cell, hy, cy
+ bool quantized; // whether this layer is quantized
+  bool enable_u8_output;       // whether this layer emits uint8 output; set to false only for
+                               // the last fused layer of the quantized RNN operator
+
DNNLRnnLayerParam(int num_layer,
index_t batch_size,
index_t seq_len,
@@ -82,7 +96,9 @@
input_size(input_size),
state_size(state_size),
proj_size(proj_size),
- seq_len(seq_len) {}
+ seq_len(seq_len),
+ quantized(false),
+ enable_u8_output(false) {}
void SetDims();
};
@@ -90,10 +106,11 @@
typedef std::vector<DNNLRnnLayerParam> LayerParamVector;
struct DNNLRnnFullParam {
RNNParam default_param;
+ DNNLRnnParam dnnl_param;
LayerParamVector layer_params;
};
-DNNLRnnFullParam DNNLRnnFullParamParser(const RNNParam& rnn_param,
+DNNLRnnFullParam DNNLRnnFullParamParser(const nnvm::NodeAttrs& attrs,
const index_t seq_len,
const index_t batch_size,
const index_t input_size);
@@ -105,7 +122,7 @@
// The memory buffer in NDArray life-cycle
NDArray workspace_;
// This points to the memory buffer from a NDArray
- char* curr_mem;
+ char* curr_mem = nullptr;
// The total bytes of the workspace of a DNNLRnnOp
size_t mem_size = 0;
// The current available memory bytes
@@ -121,7 +138,7 @@
* \param size byte number
* \param ctx Context of device enviroment
*/
- void Init(dim_t size, const Context& ctx);
+ void Init(const dim_t size, const Context& ctx);
// Return the bytes number of the buffer
const size_t Size() {
@@ -135,6 +152,8 @@
dnnl::memory* Alloc(const dnnl::memory::desc& md);
};
+typedef std::shared_ptr<dnnl::primitive_attr> shared_dnnl_attr_t;
+
/*
* Rnn Primitive.
*/
@@ -144,15 +163,15 @@
* lstm_forward, lbr_gru_forward, vanilla_rnn_forward
*/
template <typename rnn_fwd, typename... Args>
- static RnnPrimitive Create(Args&&... args) {
+ static RnnPrimitive Create(const shared_dnnl_attr_t attr, Args&&... args) {
RnnPrimitive rnn_fwd_prim;
auto fwd_desc = typename rnn_fwd::desc(std::forward<Args>(args)...);
rnn_fwd_prim.fwd_pd_.reset(
- new typename rnn_fwd::primitive_desc(fwd_desc, CpuEngine::Get()->get_engine()),
- [](typename rnn_fwd::primitive_desc* pd) {
- delete reinterpret_cast<typename rnn_fwd::primitive_desc*>(pd);
- });
+ new typename rnn_fwd::primitive_desc(
+ fwd_desc, attr ? *attr : dnnl::primitive_attr(), CpuEngine::Get()->get_engine()),
+ [](void* pd) { delete reinterpret_cast<typename rnn_fwd::primitive_desc*>(pd); });
auto fwd_pd = reinterpret_cast<typename rnn_fwd::primitive_desc*>(rnn_fwd_prim.fwd_pd_.get());
+ rnn_fwd_prim.attr_ = attr;
rnn_fwd_prim.weights_layer_desc_ = fwd_pd->weights_layer_desc();
rnn_fwd_prim.weights_iter_desc_ = fwd_pd->weights_iter_desc();
rnn_fwd_prim.weights_proj_desc_ = fwd_pd->weights_projection_desc();
@@ -164,6 +183,7 @@
}
RnnPrimitive() {
+ this->attr_ = nullptr;
this->fwd_pd_ = nullptr;
this->primitive_ = nullptr;
this->weights_layer_desc_ = dnnl::memory::desc();
@@ -173,6 +193,7 @@
}
RnnPrimitive(const RnnPrimitive& rnn_fwd_prim) {
+ this->attr_ = rnn_fwd_prim.attr_;
this->fwd_pd_ = rnn_fwd_prim.fwd_pd_;
this->primitive_ = rnn_fwd_prim.primitive_;
this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_;
@@ -183,6 +204,7 @@
RnnPrimitive& operator=(const RnnPrimitive& rnn_fwd_prim) {
if (this != &rnn_fwd_prim) {
+ this->attr_ = rnn_fwd_prim.attr_;
this->fwd_pd_ = rnn_fwd_prim.fwd_pd_;
this->primitive_ = rnn_fwd_prim.primitive_;
this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_;
@@ -217,9 +239,14 @@
return workspace_desc_;
}
+ const dnnl::primitive_attr& GetPrimAttr() const {
+ return *attr_;
+ }
+
private:
std::shared_ptr<void> fwd_pd_;
std::shared_ptr<dnnl::primitive> primitive_;
+ shared_dnnl_attr_t attr_;
dnnl::memory::desc weights_layer_desc_;
dnnl::memory::desc weights_iter_desc_;
dnnl::memory::desc weights_proj_desc_;
@@ -229,7 +256,8 @@
RnnPrimitive GetRnnFwdPrim(const DNNLRnnLayerParam& layer_param,
const bool is_train,
const NDArray& data,
- const NDArray& params);
+ const NDArray& params,
+ const shared_dnnl_attr_t attr = nullptr);
/*
* Use this to manage memory and primitive of DNNL RNN forward inference.
@@ -240,11 +268,12 @@
const DNNLRnnLayerParam& layer_param,
const bool is_train,
const NDArray& data,
- const NDArray& params)
+ const NDArray& params,
+ const shared_dnnl_attr_t attr = nullptr)
: ctx_(ctx),
initialized_(false),
param_(layer_param),
- fwd_inf_(GetRnnFwdPrim(layer_param, false, data, params)) {}
+ fwd_inf_(GetRnnFwdPrim(layer_param, false, data, params, attr)) {}
void SetNewDataMem(void* x,
void* hx,
@@ -263,6 +292,10 @@
return fwd_inf_.GetPrim();
}
+ void ResetFwd(const NDArray& data, const NDArray& params, const shared_dnnl_attr_t& attr) {
+ fwd_inf_ = GetRnnFwdPrim(this->param_, false, data, params, attr);
+ }
+
const size_t GetSize() const {
const size_t size = fwd_inf_.GetLayerDesc().get_size() + fwd_inf_.GetIterDesc().get_size() +
fwd_inf_.GetProjDesc().get_size();
@@ -482,13 +515,13 @@
*/
class DNNLRnnOp {
public:
- explicit DNNLRnnOp(const RNNParam& param,
+ explicit DNNLRnnOp(const nnvm::NodeAttrs& attrs,
const int seq_len,
const int batch_size,
const int input_size)
: initialized_(false),
weights_version_(0),
- full_param_(DNNLRnnFullParamParser(param, seq_len, batch_size, input_size)) {}
+ full_param_(DNNLRnnFullParamParser(attrs, seq_len, batch_size, input_size)) {}
void Forward(const OpContext& ctx,
const std::vector<NDArray>& inputs,
diff --git a/src/operator/nn/dnnl/dnnl_rnn.cc b/src/operator/nn/dnnl/dnnl_rnn.cc
index 0d65eb9..bdda9b5 100644
--- a/src/operator/nn/dnnl/dnnl_rnn.cc
+++ b/src/operator/nn/dnnl/dnnl_rnn.cc
@@ -33,6 +33,8 @@
namespace mxnet {
namespace op {
+DMLC_REGISTER_PARAMETER(DNNLRnnParam);
+
inline int GetRnnGatesNum(int mode) {
switch (mode) {
case rnn_enum::kLstm:
@@ -88,13 +90,28 @@
reserve_size = 0;
}
-DNNLRnnFullParam DNNLRnnFullParamParser(const RNNParam& rnn_param,
+DNNLRnnFullParam DNNLRnnFullParamParser(const NodeAttrs& attrs,
const index_t seq_len,
const index_t batch_size,
const index_t input_size) {
+ const RNNParam& rnn_param = nnvm::get<RNNParam>(attrs.parsed);
DNNLRnnFullParam full_param;
full_param.default_param = rnn_param;
- const int state_size = rnn_param.state_size;
+ try {
+ full_param.dnnl_param.Init(attrs.dict, dmlc::parameter::kAllowUnknown);
+ } catch (const dmlc::ParamError& e) {
+ std::ostringstream os;
+ os << e.what();
+ os << ", in operator " << attrs.op->name << "("
+ << "name=\"" << attrs.name << "\"";
+ for (const auto& k : attrs.dict) {
+ os << ", " << k.first << "=\"" << k.second << "\"";
+ }
+ os << ")";
+ throw dmlc::ParamError(os.str());
+ }
+
+ const int state_size = rnn_param.state_size;
const int proj_size =
rnn_param.projection_size.has_value() ? rnn_param.projection_size.value() : -1;
const int iter_size =
@@ -135,15 +152,20 @@
false);
}
- // Set dims, workspace size, and state_outputs flag
+  // Set dims, workspace size, state_outputs, quantized, and enable_u8_output flags
for (auto& layer_param : layer_params) {
layer_param.SetDims();
- layer_param.state_outputs = rnn_param.state_outputs;
+ layer_param.state_outputs = rnn_param.state_outputs;
+ layer_param.quantized = full_param.dnnl_param.quantized;
+ layer_param.enable_u8_output = true;
}
+ // Quantized RNN operator produces kFloat32 outputs.
+ if (full_param.dnnl_param.quantized)
+ layer_params.back().enable_u8_output = false;
return full_param;
}
-void DNNLRnnMemMgr::Init(dim_t size, const Context& ctx) {
+void DNNLRnnMemMgr::Init(const dim_t size, const Context& ctx) {
workspace_ = NDArray(TShape({size}), ctx, false, mshadow::kUint8);
if (workspace_.data().dptr_ == nullptr)
LOG(FATAL) << "oneDNN RNN operator memory allocation error.";
@@ -178,39 +200,48 @@
RnnPrimitive GetRnnFwdPrim(const DNNLRnnLayerParam& layer_param,
const bool is_train,
const NDArray& data,
- const NDArray& params) {
+ const NDArray& params,
+ const shared_dnnl_attr_t attr) {
using namespace dnnl;
- using tag = dnnl::memory::format_tag;
- const int mode = layer_param.mode;
- memory::data_type data_type = get_dnnl_type(data.dtype());
- memory::data_type weight_type = get_dnnl_type(params.dtype());
+ using tag = dnnl::memory::format_tag;
+ const int mode = layer_param.mode;
+ memory::data_type src_layer_dtype = get_dnnl_type(data.dtype());
+ memory::data_type iter_dtype = get_dnnl_type(mshadow::kFloat32);
+ memory::data_type weight_dtype =
+ get_dnnl_type(layer_param.quantized ? mshadow::kInt8 : params.dtype());
+ memory::data_type bias_dtype = get_dnnl_type(mshadow::kFloat32);
+ memory::data_type dst_layer_dtype =
+ get_dnnl_type((layer_param.quantized && layer_param.enable_u8_output) ? mshadow::kUint8 :
+ mshadow::kFloat32);
+
const prop_kind prop = is_train ? prop_kind::forward_training : prop_kind::forward_inference;
const rnn_direction dnnl_rnn_direction = layer_param.bidirectional ?
rnn_direction::bidirectional_concat :
rnn_direction::unidirectional;
- auto src_layer_desc = memory::desc(layer_param.src_dims, data_type, tag::tnc);
- auto weight_layer_desc = memory::desc(layer_param.weight_layer_dims, weight_type, tag::any);
- auto weight_iter_desc = memory::desc(layer_param.weight_iter_dims, weight_type, tag::any);
- auto bias_desc = memory::desc(layer_param.bias_dims, data_type, tag::ldgo);
- auto dst_layer_desc = memory::desc(layer_param.dst_dims, data_type, tag::tnc);
- auto src_state_desc = memory::desc(layer_param.state_dims, data_type, tag::ldnc);
- auto src_cell_desc = memory::desc(layer_param.cell_dims, data_type, tag::ldnc);
+ auto src_layer_desc = memory::desc(layer_param.src_dims, src_layer_dtype, tag::tnc);
+ auto weight_layer_desc = memory::desc(layer_param.weight_layer_dims, weight_dtype, tag::any);
+ auto weight_iter_desc = memory::desc(layer_param.weight_iter_dims, weight_dtype, tag::any);
+ auto bias_desc = memory::desc(layer_param.bias_dims, bias_dtype, tag::ldgo);
+ auto dst_layer_desc = memory::desc(layer_param.dst_dims, dst_layer_dtype, tag::tnc);
+ auto src_state_desc = memory::desc(layer_param.state_dims, iter_dtype, tag::ldnc);
+ auto src_cell_desc = memory::desc(layer_param.cell_dims, iter_dtype, tag::ldnc);
auto weight_peep_desc = memory::desc();
auto weight_proj_desc = layer_param.proj_size > 0 ?
- memory::desc(layer_param.weight_proj_dims, weight_type, tag::any) :
+ memory::desc(layer_param.weight_proj_dims, weight_dtype, tag::any) :
memory::desc();
auto dst_state_desc = layer_param.state_outputs ?
- memory::desc(layer_param.state_dims, data_type, tag::ldnc) :
+ memory::desc(layer_param.state_dims, iter_dtype, tag::ldnc) :
memory::desc();
auto dst_cell_desc = layer_param.state_outputs ?
- memory::desc(layer_param.cell_dims, data_type, tag::ldnc) :
+ memory::desc(layer_param.cell_dims, iter_dtype, tag::ldnc) :
memory::desc();
auto fwd = RnnPrimitive();
switch (mode) {
case rnn_enum::kLstm:
- fwd = RnnPrimitive::Create<lstm_forward>(prop,
+ fwd = RnnPrimitive::Create<lstm_forward>(attr,
+ prop,
dnnl_rnn_direction,
src_layer_desc,
src_state_desc,
@@ -225,7 +256,8 @@
dst_cell_desc);
break;
case rnn_enum::kGru:
- fwd = RnnPrimitive::Create<lbr_gru_forward>(prop,
+ fwd = RnnPrimitive::Create<lbr_gru_forward>(attr,
+ prop,
dnnl_rnn_direction,
src_layer_desc,
src_state_desc,
@@ -238,6 +270,7 @@
case rnn_enum::kRnnRelu:
case rnn_enum::kRnnTanh:
fwd = RnnPrimitive::Create<vanilla_rnn_forward>(
+ attr,
prop,
mode == rnn_enum::kRnnTanh ? algorithm::eltwise_tanh : algorithm::eltwise_relu,
dnnl_rnn_direction,
@@ -449,11 +482,19 @@
auto& cpu_engine = CpuEngine::Get()->get_engine();
dnnl_args_map_t& args = net_args_;
+ int src_dtype = dtype;
+ int dst_dtype = dtype;
+ if (param_.quantized) {
+ src_dtype = mshadow::kUint8;
+ if (param_.enable_u8_output)
+ dst_dtype = mshadow::kUint8;
+ }
+
RNN_HANDLE_FUNC(RNN_HANDLE_FUNC_NAME);
// Set various data memory
- RNN_FWD_SET(SRC, param_.src_dims, format_tag::tnc, x, dtype);
- RNN_FWD_SET(DST, param_.dst_dims, format_tag::tnc, y, dtype);
+ RNN_FWD_SET(SRC, param_.src_dims, format_tag::tnc, x, src_dtype);
+ RNN_FWD_SET(DST, param_.dst_dims, format_tag::tnc, y, dst_dtype);
RNN_FWD_SET(SRC_ITER, param_.state_dims, format_tag::ldnc, hx, dtype);
if (param_.state_outputs) {
@@ -495,10 +536,25 @@
* with primitive-prefered format.
*/
void DNNLRnnForward::ReorderWeights() {
- DNNLMemoryReorder(*weights_layer_r_, *weights_layer_);
- DNNLMemoryReorder(*weights_iter_r_, *weights_iter_);
- if (param_.proj_size > 0)
- DNNLMemoryReorder(*weights_proj_r_, *weights_proj_);
+ if (param_.quantized) {
+ const dnnl::primitive_attr& attr = this->fwd_inf_.GetPrimAttr();
+ auto ReorderWithAttr = [&](dnnl::memory& src, dnnl::memory& dst) {
+ auto reorder_pd = dnnl::reorder::primitive_desc(src, dst, attr);
+ dnnl_args_map_t net_args;
+ net_args[DNNL_ARG_SRC] = src;
+ net_args[DNNL_ARG_DST] = dst;
+ DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(reorder_pd), net_args);
+ };
+ ReorderWithAttr(*weights_layer_r_, *weights_layer_);
+ ReorderWithAttr(*weights_iter_r_, *weights_iter_);
+ if (param_.proj_size > 0)
+ ReorderWithAttr(*weights_proj_r_, *weights_proj_);
+ } else {
+ DNNLMemoryReorder(*weights_layer_r_, *weights_layer_);
+ DNNLMemoryReorder(*weights_iter_r_, *weights_iter_);
+ if (param_.proj_size > 0)
+ DNNLMemoryReorder(*weights_proj_r_, *weights_proj_);
+ }
}
void AdjustGruGateOrder(char* weight,
@@ -573,7 +629,7 @@
*/
void DNNLRnnForward::SetWeightsMem(void* w_ptr, void* b_ptr, const bool is_train, const int dtype) {
using format_tag = dnnl::memory::format_tag;
- auto dnnl_dtype = get_dnnl_type(dtype);
+ const auto dnnl_dtype = get_dnnl_type(dtype);
const size_t dtype_bytes = mshadow::mshadow_sizeof(dtype);
const size_t buffer_bytes =
@@ -702,7 +758,7 @@
// in forward training path, we use plain memory (ldxxx) as the space for weights and
// their gradients. Then, forward training primitives could fetch them from the scope
// of forward inference. And from there, we don't need to reorder the plain memory to
- // the optimal rnn-packed memory for forward inference.
+ // the optimal rnn-packed memory for forward inference
ReorderWeights();
initialized_ = true;
}
@@ -764,6 +820,19 @@
const std::vector<NDArray>& outputs) {
using format_tag = dnnl::memory::format_tag;
+ // Get the bytes of a real type
+ const NDArray& weights = inputs[rnn_enum::kParams];
+ int dtype = weights.dtype();
+ size_t dtype_bytes = mshadow::mshadow_sizeof(dtype);
+ const RNNParam& default_param = full_param_.default_param;
+ const size_t weights_size =
+ weights.data().Size() - GetRnnBiasSize(default_param.num_layers,
+ default_param.state_size,
+ default_param.bidirectional + 1,
+ default_param.mode);
+ char* weights_ptr = static_cast<char*>(weights.data().dptr_);
+ char* bias_ptr = weights_ptr + weights_size * dtype_bytes;
+
// In the `autograd.record()` context, RNNOp is required to run into
// `forward_training` mode.
const bool is_training = (op_ctx.is_train || op_ctx.need_grad);
@@ -772,7 +841,7 @@
if (fwd_inf_vec_.size() < num_fusion) {
for (auto& layer_param : full_param_.layer_params) {
fwd_inf_vec_.emplace_back(
- ctx, layer_param, false, inputs[rnn_enum::kData], inputs[rnn_enum::kParams]);
+ ctx, layer_param, false, inputs[rnn_enum::kData], inputs[rnn_enum::kParams], nullptr);
}
}
@@ -783,19 +852,6 @@
}
}
- // Get the bytes of a real type
- const NDArray& weights = inputs[rnn_enum::kParams];
- int dtype = weights.dtype();
- size_t dtype_bytes = mshadow::mshadow_sizeof(dtype);
-
- const RNNParam& default_param = full_param_.default_param;
- char* weights_ptr = static_cast<char*>(weights.data().dptr_);
- char* bias_ptr =
- weights_ptr + (weights.data().Size() - GetRnnBiasSize(default_param.num_layers,
- default_param.state_size,
- default_param.bidirectional + 1,
- default_param.mode)) *
- dtype_bytes;
for (auto& fwd_layer : fwd_inf_vec_) {
size_t single_w_bytes = fwd_layer.GetParam().single_w_size * dtype_bytes;
size_t single_b_bytes = fwd_layer.GetParam().native_single_b_size * dtype_bytes;
@@ -819,7 +875,7 @@
CHECK_EQ(num_fusion, fwd_inf_vec_.size())
<< "Layer vector's size has a different value than the number of fusion.";
if (dst_.size() < num_fusion - 1) {
- int data_dtype = outputs[rnn_enum::kOut].dtype();
+ const int data_dtype = outputs[rnn_enum::kOut].dtype();
const size_t data_dbytes = mshadow::mshadow_sizeof(data_dtype);
mgr_.Init((outputs[rnn_enum::kOut].data().Size() * data_dbytes + kDNNLAlign) * (num_fusion - 1),
op_ctx.run_ctx.ctx);
@@ -1121,7 +1177,7 @@
}
// Get data type
- int data_dtype = inputs[rnn_enum::kData].dtype();
+ int data_dtype = outputs[rnn_enum::kOut].dtype();
// Get temporary memory for output, state_out, statecell_out
const int num_layers = default_param.num_layers;
const int seq_length = default_param.seq_length_;
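
To summarize the memory data types selected above for the quantized path (as set in `GetRnnFwdPrim` and `SetNewDataMem`): the layer input is uint8, weights are int8, bias and the hidden/cell states remain float32, and the layer output is uint8 except for the last fused layer, which writes float32 directly. A compact sketch of that mapping, for illustration only:

```python
# Sketch of the oneDNN data types chosen per fused RNN layer.
def rnn_layer_dtypes(quantized, enable_u8_output, data_dtype='float32'):
    return {
        'src_layer':    'uint8' if quantized else data_dtype,
        'weights':      'int8' if quantized else data_dtype,
        'bias':         'float32',
        'src/dst_iter': 'float32',
        'dst_layer':    'uint8' if (quantized and enable_u8_output) else 'float32',
    }

print(rnn_layer_dtypes(quantized=True, enable_u8_output=True))   # intermediate fused layer
print(rnn_layer_dtypes(quantized=True, enable_u8_output=False))  # last fused layer: fp32 output
```
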
diff --git a/src/operator/quantization/dnnl/dnnl_quantize_asym-inl.h b/src/operator/quantization/dnnl/dnnl_quantize_asym-inl.h
new file mode 100644
index 0000000..9bbbd2d
--- /dev/null
+++ b/src/operator/quantization/dnnl/dnnl_quantize_asym-inl.h
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_quantize_asym-inl.h
+ * \brief implementation of asymmetric quantize operation using DNNL
+ */
+
+#ifndef MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZE_ASYM_INL_H_
+#define MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZE_ASYM_INL_H_
+#if MXNET_USE_ONEDNN == 1
+
+#include <memory>
+#include <vector>
+#include "operator/nn/dnnl/dnnl_base-inl.h"
+#include "operator/quantization/quantize_asym-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class DNNLQuantizeAsymOp {
+ public:
+ explicit DNNLQuantizeAsymOp(const nnvm::NodeAttrs& attrs)
+ : param_(nnvm::get<QuantizeAsymParam>(attrs.parsed)) {}
+
+ void Forward(const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+
+ private:
+ QuantizeAsymParam param_;
+ bool initialized_{false};
+ float cached_scale_{0.f};
+ float cached_shift_{0.f};
+ dnnl::memory::desc o_desc_;
+ dnnl_args_map_t args_;
+ std::shared_ptr<dnnl::reorder> fwd_pd_;
+};
+
+void DNNLQuantizeAsymOp::Forward(const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ using mshadow::red::limits::MaxValue;
+ using mshadow::red::limits::MinValue;
+ NDArray in_buffer = inputs[0];
+ float scale = 0.f;
+ float shift = 0.f;
+
+ // Pass through quantized data
+ if (inputs[0].dtype() == mshadow::kUint8) {
+ *outputs[1].data().dptr<float>() = 1;
+ *outputs[2].data().dptr<float>() = 0;
+ if (req[0] != kWriteInplace) {
+ const_cast<NDArray&>(outputs[0]).CopyFrom(*inputs[0].GetDNNLData());
+ DNNLStream::Get()->Submit();
+ }
+ } else {
+ in_buffer = inputs[0].Reorder2Default();
+ const dnnl::memory* i_mem = in_buffer.GetDNNLData();
+ float* in_ptr = in_buffer.data().dptr<float>();
+ const int nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+ if (inputs[0].dtype() == mshadow::kInt8) {
+ *outputs[1].data().dptr<float>() = 1;
+ *outputs[2].data().dptr<float>() = 128;
+#pragma omp parallel for num_threads(nthreads)
+ for (index_t i = 0; i < static_cast<index_t>(in_buffer.shape().Size()); ++i) {
+ in_ptr[i] += 128.0f;
+ }
+ } else if (inputs[0].dtype() == mshadow::kFloat32) {
+ if (param_.min_calib_range.has_value() && param_.max_calib_range.has_value()) {
+ scale =
+ MaxValue<uint8_t>() / (param_.max_calib_range.value() - param_.min_calib_range.value());
+ shift = MaxValue<uint8_t>() - param_.max_calib_range.value() * scale;
+ } else {
+ float data_min = mshadow::red::limits::MaxValue<float>();
+ float data_max = mshadow::red::limits::MinValue<float>();
+ std::vector<float> data_maxs(nthreads, data_max);
+ std::vector<float> data_mins(nthreads, data_min);
+#pragma omp parallel for num_threads(nthreads)
+ for (index_t i = 0; i < static_cast<index_t>(in_buffer.shape().Size()); i++) {
+ int tid = omp_get_thread_num();
+ if (in_ptr[i] > data_maxs[tid])
+ data_maxs[tid] = in_ptr[i];
+ if (in_ptr[i] < data_mins[tid])
+ data_mins[tid] = in_ptr[i];
+ }
+ for (index_t i = 0; i < nthreads; i++) {
+ if (data_maxs[i] > data_max)
+ data_max = data_maxs[i];
+ if (data_mins[i] < data_min)
+ data_min = data_mins[i];
+ }
+ scale = MaxValue<uint8_t>() / (data_max - data_min);
+ shift = MaxValue<uint8_t>() - data_max * scale;
+ }
+
+ if (initialized_ && (cached_scale_ != scale || cached_shift_ != shift))
+ initialized_ = false;
+ }
+
+ *outputs[1].data().dptr<float>() = scale;
+ *outputs[2].data().dptr<float>() = shift;
+
+ if (!initialized_) {
+ cached_scale_ = scale;
+ cached_shift_ = shift;
+ dnnl::primitive_attr attr;
+ attr.set_rnn_data_qparams(scale, shift);
+ const dnnl::engine& cpu_engine = mxnet::CpuEngine::Get()->get_engine();
+ const dnnl::memory::desc& i_desc = i_mem->get_desc();
+ o_desc_ = i_desc;
+ o_desc_.data.data_type = get_dnnl_type_t(outputs[0].dtype());
+ dnnl::reorder::primitive_desc reorder_pd(cpu_engine, i_desc, cpu_engine, o_desc_, attr);
+ fwd_pd_ = std::make_shared<dnnl::reorder>(reorder_pd);
+ initialized_ = true;
+ }
+ dnnl_output_t o_mem = CreateDNNLMem(outputs[0], o_desc_, req[0]);
+ args_[DNNL_ARG_FROM] = *i_mem;
+ args_[DNNL_ARG_TO] = *o_mem.second;
+ DNNLStream::Get()->RegisterPrimArgs(*fwd_pd_, args_);
+ CommitOutput(outputs[0], o_mem);
+ DNNLStream::Get()->Submit();
+ }
+}
+
+void DNNLQuantizeAsymForward(const OpStatePtr& state_ptr,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ if (inputs[0].shape().ndim() == 3 && inputs[0].dtype() == mshadow::kFloat32) {
+ DNNLQuantizeAsymOp& op = state_ptr.get_state<DNNLQuantizeAsymOp>();
+ op.Forward(ctx, inputs, req, outputs);
+ } else {
+ FallBackCompute(QuantizeAsymForward<cpu>, state_ptr, ctx, inputs, req, outputs);
+ }
+}
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_USE_ONEDNN == 1
+#endif // MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZE_ASYM_INL_H_
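
A numpy sketch (illustration only, not the DNNL code path) of the scale/shift math in the float32 branch above when no calibration range is given: the observed range is mapped onto [0, 255] with the maximum landing at 255, and `(q - shift) / scale` recovers the input up to half a quantization step:

```python
import numpy as np

data = np.random.uniform(-2.0, 3.0, size=(5, 4)).astype(np.float32)

# Same formulas as the float32 branch above (uint8 max is 255).
data_min, data_max = float(data.min()), float(data.max())
scale = 255.0 / (data_max - data_min)
shift = 255.0 - data_max * scale

q = np.clip(np.floor(data * scale + shift + 0.5), 0, 255).astype(np.uint8)
dequantized = (q.astype(np.float32) - shift) / scale   # relation assumed by the LSTM primitive
print('max abs error:', np.abs(dequantized - data).max(), ' half step:', 0.5 / scale)
```
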
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_rnn-inl.h b/src/operator/quantization/dnnl/dnnl_quantized_rnn-inl.h
new file mode 100644
index 0000000..cdd5417
--- /dev/null
+++ b/src/operator/quantization/dnnl/dnnl_quantized_rnn-inl.h
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_quantized_rnn-inl.h
+ * \brief Common functions for quantized recurrent neural network
+ * \author Zixuan Wei
+ */
+
+#ifndef MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZED_RNN_INL_H_
+#define MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZED_RNN_INL_H_
+
+#if MXNET_USE_ONEDNN == 1
+
+#include <vector>
+#include "operator/nn/dnnl/dnnl_rnn-inl.h"
+#include "operator/rnn-inl.h"
+#include "operator/quantization/quantized_rnn-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class DNNLQuantizedRnnOp {
+ public:
+ explicit DNNLQuantizedRnnOp(const nnvm::NodeAttrs& attrs,
+ const int seq_len,
+ const int batch_size,
+ const int input_size)
+ : initialized_(false),
+ weights_ver_(0),
+ rnn_attr_(new dnnl::primitive_attr),
+ full_param_(DNNLRnnFullParamParser(attrs, seq_len, batch_size, input_size)) {}
+
+ void Forward(const OpContext& op_ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+
+ private:
+ bool initialized_;
+ size_t weights_ver_;
+ shared_dnnl_attr_t rnn_attr_;
+ DNNLRnnFullParam full_param_;
+ DNNLRnnMemMgr mgr_;
+ std::vector<DNNLRnnForward> fwd_inf_vec_; // forward inference layers
+
+ // Used to store the intermediate results of multi-layer
+ std::vector<dnnl::memory*> dst_;
+ // According to
+ // https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html, the
+ // non-symmetric quantization is assumed by LSTM primitive. Namely, the
+ // formula is:
+ // data_f32 = (data_u8 - shift) / scale
+ float cached_data_shift_{0.0};
+ float cached_data_scale_{0.0};
+ void Init(const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs);
+};
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_USE_ONEDNN == 1
+#endif // MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZED_RNN_INL_H_
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_rnn.cc b/src/operator/quantization/dnnl/dnnl_quantized_rnn.cc
new file mode 100644
index 0000000..73393d9
--- /dev/null
+++ b/src/operator/quantization/dnnl/dnnl_quantized_rnn.cc
@@ -0,0 +1,366 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_quantized_rnn.cc
+ * \brief Common functions for quantized recurrent neural network
+ * \author Zixuan Wei
+ */
+
+#if MXNET_USE_ONEDNN == 1
+
+#include "operator/quantization/quantization_utils.h"
+#include "operator/quantization/dnnl/dnnl_quantized_rnn-inl.h"
+
+namespace mxnet {
+namespace op {
+
+std::vector<float> GetDNNLRnnWeightsQParams(const DNNLRnnFullParam& full_param, float* w_ptr) {
+ const int nthreads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+ const int num_gates = 4;
+ const RNNParam& default_param = full_param.default_param;
+ const LayerParamVector& layer_params = full_param.layer_params;
+
+ const DNNLRnnLayerParam& layer_param0 = layer_params.at(0);
+ const size_t w_size0 = layer_param0.single_w_size;
+ const size_t wx_size0 = num_gates * layer_param0.state_size * layer_param0.input_size;
+ const size_t wh_size0 = num_gates * layer_param0.state_size * layer_param0.state_size;
+
+ int directions = 1;
+ float* wx = w_ptr;
+ float* wh = wx + wx_size0;
+ float* fake_wx = wx;
+ float* fake_wh = wh;
+
+ std::vector<float> wx_goi_max;
+ std::vector<float> wh_goi_max;
+ if (default_param.bidirectional) {
+ directions = 2;
+ wx_goi_max.resize(wx_size0);
+ wh_goi_max.resize(wh_size0);
+ fake_wx = wx_goi_max.data();
+ fake_wh = wh_goi_max.data();
+#pragma omp parallel for num_threads(nthreads)
+ for (index_t i = 0; i < static_cast<index_t>(wx_size0); ++i) {
+ fake_wx[i] = MaxAbs(wx[i], wx[i + w_size0]);
+ }
+#pragma omp parallel for num_threads(nthreads)
+ for (index_t i = 0; i < static_cast<index_t>(wh_size0); ++i) {
+ fake_wh[i] = MaxAbs(wh[i], wh[i + w_size0]);
+ }
+ }
+ std::vector<float> w_max(num_gates * layer_param0.state_size, 0.0);
+ const index_t input_size = layer_param0.input_size; // input
+ const index_t state_size = layer_param0.state_size; // state
+ const index_t gates_nblks = num_gates * layer_param0.state_size; // gates * state
+ for (index_t go = 0; go < gates_nblks; ++go) {
+ float tmp_max = w_max[go];
+ for (index_t i = 0; i < input_size; ++i) {
+ tmp_max = MaxAbs(fake_wx[go * input_size + i], tmp_max);
+ }
+ for (index_t i = 0; i < state_size; ++i) {
+ tmp_max = MaxAbs(fake_wh[go * state_size + i], tmp_max);
+ }
+ w_max[go] = tmp_max;
+ }
+ wx += layer_param0.single_w_size * directions;
+ wh += layer_param0.single_w_size * directions;
+
+ std::vector<float> goi_max(wh_size0, 0.0);
+ for (size_t lyr = 1; lyr < layer_params.size(); ++lyr) {
+ const DNNLRnnLayerParam& layer_param = layer_params.at(lyr);
+ const int weight_nblks = layer_param.num_layer * directions;
+ for (int blk = 0; blk < weight_nblks; ++blk) {
+#pragma omp parallel for num_threads(nthreads)
+ for (index_t i = 0; i < static_cast<index_t>(wh_size0); ++i) {
+ goi_max[i] = MaxAbs(wx[i], wh[i]);
+ }
+ for (index_t go = 0; go < gates_nblks; ++go) {
+ float tmp = w_max[go];
+// NOTE: min/max reductions have been supported since OpenMP 3.1, which was
+// released in July 2011 (hence the 201107 version check).
+#if _OPENMP >= 201107
+#pragma omp parallel for reduction(max : tmp) num_threads(nthreads)
+#endif
+ for (index_t i = 0; i < state_size; ++i) {
+ tmp = Max(goi_max[go * state_size + i], tmp);
+ }
+ w_max[go] = tmp;
+ }
+ }
+ wx += layer_param.single_w_size * directions;
+ wh = wx + wh_size0;
+ }
+#pragma omp parallel for num_threads(nthreads)
+ for (index_t i = 0; i < static_cast<index_t>(w_max.size()); ++i) {
+ w_max[i] = mshadow::red::limits::MaxValue<int8_t>() / w_max[i];
+ }
+ return w_max;
+}
+
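
A simplified numpy sketch of what `GetDNNLRnnWeightsQParams` computes for a single unidirectional layer (the multi-layer and bidirectional handling above reduces to the same idea): one scale per gate-output channel, taken over both the input-to-hidden and hidden-to-hidden weights, so the largest magnitude in each channel maps to the int8 maximum. Sizes below are illustrative:

```python
import numpy as np

num_gates, state_size, input_size = 4, 8, 16   # illustrative LSTM sizes
wx = np.random.randn(num_gates * state_size, input_size).astype(np.float32)  # input-to-hidden
wh = np.random.randn(num_gates * state_size, state_size).astype(np.float32)  # hidden-to-hidden

# Per (gate, output-channel) row: scale = int8_max / max(|wx_row|, |wh_row|)
channel_max = np.maximum(np.abs(wx).max(axis=1), np.abs(wh).max(axis=1))
weight_scales = 127.0 / channel_max
print(weight_scales.shape)   # (32,) -> one scale per num_gates * state_size channel
```
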
+void DNNLQuantizedRnnOp::Init(const OpContext& op_ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ using format_tag = dnnl::memory::format_tag;
+
+ // Get the bytes of a real type
+ const Context& ctx = op_ctx.run_ctx.ctx;
+ const NDArray& weights = inputs[rnn_enum::kParams];
+ int dtype = weights.dtype();
+ int weights_dtype = weights.dtype();
+ size_t dtype_bytes = mshadow::mshadow_sizeof(dtype);
+ const RNNParam& default_param = full_param_.default_param;
+ const size_t weights_size =
+ weights.data().Size() - GetRnnBiasSize(default_param.num_layers,
+ default_param.state_size,
+ default_param.bidirectional + 1,
+ default_param.mode);
+ char* weights_ptr = static_cast<char*>(weights.data().dptr_);
+ char* bias_ptr = weights_ptr + weights_size * dtype_bytes;
+
+ // In the `autograd.record()` context, RNNOp is required to run into
+ // `forward_training` mode.
+
+ const size_t num_fusion = full_param_.layer_params.size();
+ if (fwd_inf_vec_.size() < num_fusion) {
+    size_t buffer_size = 0;  // Number of elements (not bytes) in the buffer
+ for (auto& layer_param : full_param_.layer_params) {
+ buffer_size += layer_param.workspace_size + layer_param.reserve_size;
+ }
+ buffer_size += outputs[rnn_enum::kOut].data().Size() * (num_fusion - 1);
+ buffer_size += kDNNLAlign * num_fusion * 5; // Add margin for alignment
+
+ for (auto& layer_param : full_param_.layer_params) {
+ fwd_inf_vec_.emplace_back(
+ ctx, layer_param, false, inputs[rnn_enum::kData], inputs[rnn_enum::kParams], rnn_attr_);
+ buffer_size += fwd_inf_vec_.back().GetSize();
+ }
+ mgr_.Init(buffer_size, ctx);
+ }
+
+ for (auto& fwd_layer : fwd_inf_vec_) {
+ size_t single_w_bytes = fwd_layer.GetParam().single_w_size * dtype_bytes;
+ size_t single_b_bytes = fwd_layer.GetParam().native_single_b_size * dtype_bytes;
+ size_t directions = fwd_layer.GetParam().bidirectional ? 2 : 1;
+ size_t layer_weights_bytes = single_w_bytes * directions;
+ size_t layer_bias_bytes = single_b_bytes * directions; // Native MXNet has double bias
+
+ if (!fwd_layer.IsInitialized())
+ fwd_layer.SetWeightsMem(weights_ptr, bias_ptr, false, weights_dtype);
+ weights_ptr += layer_weights_bytes;
+ bias_ptr += layer_bias_bytes;
+ }
+
+ CHECK_EQ(num_fusion, fwd_inf_vec_.size())
+ << "Layer vector's size has a different value than the number of fusion.";
+ if (dst_.size() < num_fusion - 1) {
+ const int data_dtype = outputs[rnn_enum::kOut].dtype();
+    // Here we need `fwd_inf_vec_.size() - 1` buffers for the intermediate
+    // results of the fused layers. The result of the last fused layer is
+    // written directly to `outputs[rnn_enum::kOut]`, so `fwd_inf_vec_.back()`
+    // is excluded when allocating space for the intermediate results.
+ for (std::vector<DNNLRnnForward>::const_iterator fwd = fwd_inf_vec_.begin();
+ fwd != fwd_inf_vec_.end() - 1;
+ ++fwd)
+ dst_.push_back(
+ mgr_.Alloc({fwd->GetParam().dst_dims, get_dnnl_type(data_dtype), format_tag::tnc}));
+ }
+
+ initialized_ = true;
+}
+
+template <typename DNNLRnnX>
+inline void RegisterDNNLRnn(DNNLRnnX const& rnn) {
+ DNNLStream::Get()->RegisterPrimArgs(rnn.GetFwd(), rnn.GetArgsMap());
+}
+
+template <>
+inline void RegisterDNNLRnn(DNNLRnnBackward const& rnn) {
+ DNNLStream::Get()->RegisterPrimArgs(rnn.GetBwd(), rnn.GetArgsMap());
+ rnn.SetNativeWeightsGrads();
+}
+
+void DNNLQuantizedRnnOp::Forward(const OpContext& op_ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ TmpMemMgr::Get()->Init(op_ctx.requested[0]);
+
+ const RNNParam& default_param = full_param_.default_param;
+ const uint32_t num_base_inputs = GetRnnNumInputs(default_param);
+ float data_scale = inputs[num_base_inputs + quantized_rnn::kDataScale].data().dptr<float>()[0];
+ float data_shift = inputs[num_base_inputs + quantized_rnn::kDataShift].data().dptr<float>()[0];
+
+ const bool need_reset_weight = (!dmlc::GetEnv("MXNET_RNN_USE_WEIGHT_CACHE", 0) &&
+ weights_ver_ != inputs[rnn_enum::kParams].version()) ?
+ true :
+ false;
+ const NDArray& weights = inputs.at(rnn_enum::kParams);
+ float* weights_ptr = weights.data().dptr<float>();
+ if (!initialized_ || fwd_inf_vec_.empty()) {
+ weights_ver_ = inputs[rnn_enum::kParams].version();
+ cached_data_scale_ = data_scale;
+ cached_data_shift_ = data_shift;
+ rnn_attr_->set_rnn_data_qparams(data_scale, data_shift);
+ if (need_reset_weight || fwd_inf_vec_.empty())
+ rnn_attr_->set_rnn_weights_qparams(0 + (1 << 3) + (1 << 4),
+ GetDNNLRnnWeightsQParams(full_param_, weights_ptr));
+ }
+
+ // Initialize weights version
+ if (!initialized_ && weights_ver_ == 0) {
+ weights_ver_ = inputs[rnn_enum::kParams].version();
+ cached_data_scale_ = data_scale;
+ cached_data_shift_ = data_shift;
+ }
+
+ if (!fwd_inf_vec_.empty() &&
+ ((cached_data_scale_ != data_scale || cached_data_shift_ != data_shift))) {
+ initialized_ = false;
+ weights_ver_ = inputs[rnn_enum::kParams].version();
+ cached_data_scale_ = data_scale;
+ cached_data_shift_ = data_shift;
+ }
+
+ // Check if weights NDArray was changed. If so, reset initialized_
+ if (fwd_inf_vec_.size() > 0 && weights_ver_ != inputs[rnn_enum::kParams].version()) {
+ initialized_ = false;
+ for (auto& fwd : fwd_inf_vec_)
+ fwd.Reset();
+ weights_ver_ = inputs[rnn_enum::kParams].version();
+ cached_data_scale_ = data_scale;
+ cached_data_shift_ = data_shift;
+ }
+
+ if (!initialized_ || fwd_inf_vec_.empty()) {
+ Init(op_ctx, inputs, req, outputs);
+ }
+
+ // Get data type
+ int data_dtype = outputs[rnn_enum::kOut].dtype();
+ // Get temporary memory for output, state_out, statecell_out
+ const int num_layers = default_param.num_layers;
+ const int seq_length = default_param.seq_length_;
+ const int batch_size = default_param.batch_size_;
+ const int state_size = default_param.state_size;
+ const int directions = default_param.bidirectional ? 2 : 1;
+ dnnl::memory::desc dst_desc({seq_length, batch_size, directions * state_size},
+ get_dnnl_type(data_dtype),
+ dnnl::memory::format_tag::tnc);
+ dnnl::memory::desc state_desc({num_layers, directions, batch_size, state_size},
+ get_dnnl_type(data_dtype),
+ dnnl::memory::format_tag::ldnc);
+ auto out_mem = CreateDNNLMem(outputs[rnn_enum::kOut], dst_desc, req[rnn_enum::kOut]);
+ dnnl_output_t stateout_mem;
+ dnnl_output_t statecellout_mem;
+
+ // Get input & output NDArray
+ char* src = static_cast<char*>(inputs[rnn_enum::kData].data().dptr_);
+ char* src_state = static_cast<char*>(inputs[rnn_enum::kState].data().dptr_);
+ char* dst = static_cast<char*>(out_mem.second->get_data_handle());
+ char* dst_state = nullptr; // Output state
+ char* src_state_cell = nullptr; // Used in LSTM for cell state
+ char* dst_state_cell = nullptr; // Used in LSTM for cell state
+ const size_t cell_bytes = (default_param.bidirectional + 1) * default_param.batch_size_ *
+ default_param.state_size * mshadow::mshadow_sizeof(data_dtype);
+
+ if (default_param.state_outputs && req[rnn_enum::kStateOut] != kNullOp) {
+ stateout_mem =
+ CreateDNNLMem(outputs[rnn_enum::kStateOut], state_desc, req[rnn_enum::kStateOut]);
+ dst_state = static_cast<char*>(stateout_mem.second->get_data_handle());
+ }
+
+ if (default_param.mode == rnn_enum::kLstm) {
+ src_state_cell = static_cast<char*>(inputs[rnn_enum::kStateCell].data().dptr_);
+ if (default_param.state_outputs && req[rnn_enum::kStateCellOut] != kNullOp) {
+ statecellout_mem =
+ CreateDNNLMem(outputs[rnn_enum::kStateCellOut], state_desc, req[rnn_enum::kStateCellOut]);
+ dst_state_cell = static_cast<char*>(statecellout_mem.second->get_data_handle());
+ }
+ }
+
+ if (fwd_inf_vec_.size() == 1) {
+ fwd_inf_vec_.front().SetNewDataMem(
+ src, src_state, src_state_cell, dst, dst_state, dst_state_cell, data_dtype);
+ } else {
+ CHECK_EQ(fwd_inf_vec_.size(), dst_.size() + 1) << "Output memory error.";
+ size_t cell_bytes = (default_param.bidirectional + 1) * default_param.batch_size_ *
+ default_param.state_size * mshadow::mshadow_sizeof(data_dtype);
+
+ // Set input data memory for the first layer. This stores intermediate
+ // output results in this->xxx, used as the source input of the next layer.
+ fwd_inf_vec_.front().SetNewDataMem(src,
+ src_state,
+ src_state_cell,
+ this->dst_.front()->get_data_handle(),
+ dst_state,
+ dst_state_cell,
+ data_dtype);
+ // 1st_lyr -> dst_handle -> next_lyr -> dst_handle -> next_lyr -> ...
+ for (size_t lyr = 1; lyr < fwd_inf_vec_.size() - 1; ++lyr) {
+ src_state += cell_bytes;
+ if (src_state_cell)
+ src_state_cell += cell_bytes;
+ if (dst_state)
+ dst_state += cell_bytes;
+ if (dst_state_cell)
+ dst_state_cell += cell_bytes;
+ fwd_inf_vec_.at(lyr).SetNewDataMem(this->dst_.at(lyr - 1)->get_data_handle(),
+ src_state,
+ src_state_cell,
+ this->dst_.at(lyr)->get_data_handle(),
+ dst_state,
+ dst_state_cell,
+ data_dtype);
+ }
+ // Set output data memory for the last layer.
+ src_state += cell_bytes;
+ if (src_state_cell)
+ src_state_cell += cell_bytes;
+ if (dst_state)
+ dst_state += cell_bytes;
+ if (dst_state_cell)
+ dst_state_cell += cell_bytes;
+ fwd_inf_vec_.back().SetNewDataMem(this->dst_.back()->get_data_handle(),
+ src_state,
+ src_state_cell,
+ dst,
+ dst_state,
+ dst_state_cell,
+ data_dtype);
+ }
+
+ for (auto& inf_lyr : fwd_inf_vec_)
+ RegisterDNNLRnn(inf_lyr);
+
+ CommitOutput(outputs[rnn_enum::kOut], out_mem);
+ if (default_param.state_outputs) {
+ CommitOutput(outputs[rnn_enum::kStateOut], stateout_mem);
+ if (default_param.mode == rnn_enum::kLstm)
+ CommitOutput(outputs[rnn_enum::kStateCellOut], statecellout_mem);
+ }
+ DNNLStream::Get()->Submit();
+}
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_USE_ONEDNN == 1
diff --git a/src/operator/quantization/quantize_asym-inl.h b/src/operator/quantization/quantize_asym-inl.h
new file mode 100644
index 0000000..3aa44c4
--- /dev/null
+++ b/src/operator/quantization/quantize_asym-inl.h
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file quantize_asym-inl.h
+ * \brief implementation of asymmetric quantize operation
+ */
+#ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZE_ASYM_INL_H_
+#define MXNET_OPERATOR_QUANTIZATION_QUANTIZE_ASYM_INL_H_
+
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mshadow/tensor.h>
+#include <mxnet/operator_util.h>
+#include <vector>
+
+#include "../mshadow_op.h"
+#include "../mxnet_op.h"
+#include "../tensor/broadcast_reduce_op.h"
+#include "./quantization_utils.h"
+
+namespace mxnet {
+namespace op {
+
+struct QuantizeAsymParam : public dmlc::Parameter<QuantizeAsymParam> {
+ dmlc::optional<float> min_calib_range;
+ dmlc::optional<float> max_calib_range;
+
+ DMLC_DECLARE_PARAMETER(QuantizeAsymParam) {
+ DMLC_DECLARE_FIELD(min_calib_range)
+ .set_default(dmlc::optional<float>())
+ .describe(
+ "The minimum scalar value in the form of float32. If "
+ "present, it will be used to "
+ "quantize the fp32 data.");
+ DMLC_DECLARE_FIELD(max_calib_range)
+ .set_default(dmlc::optional<float>())
+ .describe(
+ "The maximum scalar value in the form of float32. If "
+ "present, it will be used to "
+ "quantize the fp32 data.");
+ }
+};
+
+// quantize float to uint8_t
+struct quantize_asymmetric {
+ template <typename DstDType, typename SrcDType>
+ MSHADOW_XINLINE static void Map(int i,
+ DstDType* out,
+ float* oscale,
+ float* oshift,
+ const SrcDType* in,
+ const float scale,
+ const float shift) {
+ out[i] = static_cast<DstDType>(in[i] * scale + shift + 0.5);
+ *oscale = scale;
+ *oshift = shift;
+ }
+};
+
+template <typename xpu>
+class QuantizeAsymOp {
+ public:
+ explicit QuantizeAsymOp(const nnvm::NodeAttrs& attrs) : attrs_(attrs) {}
+
+ void Forward(const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ using namespace mshadow;
+ using namespace mxnet_op;
+ using mshadow::red::limits::MaxValue;
+ using mshadow::red::limits::MinValue;
+
+ CHECK_EQ(outputs[0].type_flag_, mshadow::kUint8)
+ << "Asymmetric quantization only supports uint8 outputs.";
+ mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
+ const int input_data_dtype = inputs[0].type_flag_;
+ if (input_data_dtype == mshadow::kUint8) {
+ *outputs[1].dptr<float>() = 1;
+ *outputs[2].dptr<float>() = 0;
+ UnaryOp::IdentityCompute<xpu>(attrs_, ctx, {inputs[0]}, req, outputs);
+ } else if (input_data_dtype == mshadow::kInt8) {
+ const float scale = 1;
+ const float shift = 128;
+ Kernel<quantize_asymmetric, xpu>::Launch(s,
+ outputs[0].Size(),
+ outputs[0].dptr<uint8_t>(),
+ outputs[1].dptr<float>(),
+ outputs[2].dptr<float>(),
+ inputs[0].dptr<int8_t>(),
+ scale,
+ shift);
+ } else if (input_data_dtype == mshadow::kFloat32) {
+ const QuantizeAsymParam& param = nnvm::get<QuantizeAsymParam>(attrs_.parsed);
+ if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+ const float scale =
+ MaxValue<uint8_t>() / (param.max_calib_range.value() - param.min_calib_range.value());
+ const float shift = MaxValue<uint8_t>() - param.max_calib_range.value() * scale;
+ Kernel<quantize_asymmetric, xpu>::Launch(s,
+ outputs[0].Size(),
+ outputs[0].dptr<uint8_t>(),
+ outputs[1].dptr<float>(),
+ outputs[2].dptr<float>(),
+ inputs[0].dptr<float>(),
+ scale,
+ shift);
+ } else {
+ mxnet::TShape src_shape, dst_shape;
+ const size_t float_bytes = sizeof(float);
+ const size_t temp_reduce_size = ConfigReduce<xpu, float>(
+ s, inputs[0].shape_, mxnet::TShape(1, 1), &src_shape, &dst_shape);
+ Tensor<xpu, 1, char> temp_space = ctx.requested[0].get_space_typed<xpu, 1, char>(
+ Shape1(2 * float_bytes + temp_reduce_size), s);
+ const int dev_id = ctx.run_ctx.ctx.dev_id;
+ TBlob in_min_t(
+ reinterpret_cast<float*>(temp_space.dptr_), Shape1(1), xpu::kDevMask, dev_id);
+ TBlob in_max_t(
+ reinterpret_cast<float*>(temp_space.dptr_) + 1, Shape1(1), xpu::kDevMask, dev_id);
+ Tensor<xpu, 1, char> workspace(
+ temp_space.dptr_ + 2 * float_bytes, Shape1(temp_reduce_size), s);
+ broadcast::Reduce<red::minimum, 2, float, mshadow::op::identity>(
+ s, in_min_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape));
+ broadcast::Reduce<red::maximum, 2, float, mshadow::op::identity>(
+ s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape));
+ const float scale =
+ MaxValue<uint8_t>() / (*in_max_t.dptr<float>() - *in_min_t.dptr<float>());
+ const float shift = MaxValue<uint8_t>() - *in_max_t.dptr<float>() * scale;
+ Kernel<quantize_asymmetric, xpu>::Launch(s,
+ outputs[0].Size(),
+ outputs[0].dptr<uint8_t>(),
+ outputs[1].dptr<float>(),
+ outputs[2].dptr<float>(),
+ inputs[0].dptr<float>(),
+ scale,
+ shift);
+ }
+ } else {
+      LOG(FATAL) << "Asymmetric quantization only supports int8, uint8 and "
+ "float inputs";
+ }
+ }
+
+ private:
+ nnvm::NodeAttrs attrs_;
+};
+
+template <typename xpu>
+void QuantizeAsymForward(const OpStatePtr& state_ptr,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ QuantizeAsymOp<xpu>& op = state_ptr.get_state<QuantizeAsymOp<xpu>>();
+ op.Forward(ctx, inputs, req, outputs);
+}
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZE_ASYM_INL_H_
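
A numpy sketch of the three input cases the operator handles (uint8 pass-through, int8 re-biased by 128, float32 quantized with a calibrated range), matching the kernel formula `out[i] = in[i] * scale + shift + 0.5`; the clip is added here only to keep the sketch safe:

```python
import numpy as np

def quantize_asym(data, scale, shift):
    # Truncation after adding 0.5 mimics the static_cast in the kernel above.
    return np.clip(data.astype(np.float32) * scale + shift + 0.5, 0, 255).astype(np.uint8)

s8 = np.array([-128, -1, 0, 127], dtype=np.int8)
print(quantize_asym(s8, scale=1.0, shift=128.0))   # [  0 127 128 255]  int8 -> uint8 re-bias

f32 = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
scale = 255.0 / (1.0 - (-1.0))                     # from (min_calib_range, max_calib_range)
shift = 255.0 - 1.0 * scale
print(quantize_asym(f32, scale, shift))            # [  0 128 255]
```
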
diff --git a/src/operator/quantization/quantize_asym.cc b/src/operator/quantization/quantize_asym.cc
new file mode 100644
index 0000000..4cb2669
--- /dev/null
+++ b/src/operator/quantization/quantize_asym.cc
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file quantize_asym.cc
+ * \brief implementation of asymmetric quantize operation
+ */
+
+#include <string>
+
+#include "operator/quantization/quantize_asym-inl.h"
+#if MXNET_USE_ONEDNN == 1
+#include "operator/quantization/dnnl/dnnl_quantize_asym-inl.h"
+#endif
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(QuantizeAsymParam);
+
+inline bool QuantizeAsymShape(const nnvm::NodeAttrs& attrs,
+ mxnet::ShapeVector* in_attrs,
+ mxnet::ShapeVector* out_attrs) {
+ CHECK_EQ(in_attrs->size(), 1U);
+ CHECK_EQ(out_attrs->size(), 3U);
+
+ mxnet::TShape dshape = in_attrs->at(0);
+ SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
+ SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape(1, 1));
+ SHAPE_ASSIGN_CHECK(*out_attrs, 2, TShape(1, 1));
+
+ if (out_attrs->at(0).ndim() > 0) {
+ dshape[0] = out_attrs->at(0)[0];
+ SHAPE_ASSIGN_CHECK(*in_attrs, 0, dshape);
+ }
+
+ return !shape_is_none(out_attrs->at(0));
+}
+
+inline bool QuantizeAsymType(const nnvm::NodeAttrs& attrs,
+ std::vector<int>* in_attrs,
+ std::vector<int>* out_attrs) {
+ CHECK_EQ(in_attrs->size(), 1U);
+ CHECK_EQ(out_attrs->size(), 3U);
+
+ CHECK_EQ(in_attrs->at(0), mshadow::kFloat32);
+
+ TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kUint8);
+ TYPE_ASSIGN_CHECK(*out_attrs, 1, mshadow::kFloat32);
+ TYPE_ASSIGN_CHECK(*out_attrs, 2, mshadow::kFloat32);
+
+ return !type_is_none(out_attrs->at(0));
+}
+
+bool QuantizeAsymStorageType(const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int>* in_attrs,
+ std::vector<int>* out_attrs) {
+ *dispatch_mode = DispatchMode::kFCompute;
+#if MXNET_USE_ONEDNN == 1
+ if (dev_mask == mshadow::cpu::kDevMask) {
+ *dispatch_mode = DispatchMode::kFComputeEx;
+ }
+#endif
+ out_attrs->at(0) = kDefaultStorage;
+ out_attrs->at(1) = kDefaultStorage;
+ out_attrs->at(2) = kDefaultStorage;
+ return true;
+}
+
+OpStatePtr CreateQuantizeAsymState(const nnvm::NodeAttrs& attrs,
+ const Context& ctx,
+ const std::vector<TShape>& in_shapes,
+ const std::vector<int>& in_types) {
+ OpStatePtr state;
+ if (ctx.dev_type == kGPU) {
+ state = OpStatePtr::Create<QuantizeAsymOp<gpu>>(attrs);
+ } else {
+#if MXNET_USE_ONEDNN == 1
+ if (in_shapes[0].ndim() == 3 && in_types[0] == mshadow::kFloat32) {
+ state = OpStatePtr::Create<DNNLQuantizeAsymOp>(attrs);
+ return state;
+ }
+#endif
+ // Fall back to the plain CPU implementation when the oneDNN path is not taken.
+ state = OpStatePtr::Create<QuantizeAsymOp<cpu>>(attrs);
+ }
+ return state;
+}
+
+NNVM_REGISTER_OP(_contrib_quantize_asym)
+ .describe(R"code(Quantize a input tensor from float to uint8_t.
+Output `scale` and `shift` are scalar floats that specify the quantization
+parameters for the input data. The output is calculated using the following equation:
+
+`out[i] = in[i] * scale + shift + 0.5`,
+
+where `scale = uint8_range / (max_range - min_range)` and
+`shift = numeric_limits<T>::max - max_range * scale`.
+
+.. Note::
+ This operator only supports forward propagation. DO NOT use it in training.)code" ADD_FILELINE)
+ .set_attr_parser(ParamParser<QuantizeAsymParam>)
+ .set_num_inputs(1)
+ .set_num_outputs(3)
+ .set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector<std::string>{"data"};
+ })
+ .set_attr<nnvm::FListOutputNames>("FListOutputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector<std::string>{"output", "scale", "shift"};
+ })
+ .set_attr<mxnet::FInferShape>("FInferShape", QuantizeAsymShape)
+ .set_attr<nnvm::FInferType>("FInferType", QuantizeAsymType)
+ .set_attr<FInferStorageType>("FInferStorageType", QuantizeAsymStorageType)
+ .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+ .set_attr<FCreateOpState>("FCreateOpState", CreateQuantizeAsymState)
+#if MXNET_USE_ONEDNN == 1
+ .set_attr<bool>("TIsDNNL", true)
+ .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", DNNLQuantizeAsymForward)
+#endif
+ .set_attr<FStatefulCompute>("FStatefulCompute<cpu>", QuantizeAsymForward<cpu>)
+ .set_attr<FNeedCalibrateInput>("FNeedCalibrateInput",
+ [](const NodeAttrs& attrs) { return std::vector<int>{0}; })
+ .set_attr<FResourceRequest>("FResourceRequest",
+ [](const NodeAttrs& attrs) {
+ const QuantizeAsymParam& param =
+ nnvm::get<QuantizeAsymParam>(attrs.parsed);
+ if (param.min_calib_range.has_value() &&
+ param.max_calib_range.has_value()) {
+ return std::vector<ResourceRequest>();
+ } else {
+ return std::vector<ResourceRequest>(
+ 1, ResourceRequest::kTempSpace);
+ }
+ })
+ .add_argument("data", "NDArray-or-Symbol", "A ndarray/symbol of type `float32`")
+ .add_arguments(QuantizeAsymParam::__FIELDS__());
+
+} // namespace op
+} // namespace mxnet
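
A minimal imperative usage sketch of the operator registered above. It assumes the standard contrib Python binding (`mx.nd.contrib.quantize_asym`) is generated for `_contrib_quantize_asym`; that binding is not exercised directly by this patch's tests, so treat it as illustrative:

import mxnet as mx

data = mx.nd.random.uniform(-1.0, 1.0, shape=(5, 2, 16))
# Three outputs: the uint8 tensor plus the scalar scale and shift that produced it.
qdata, scale, shift = mx.nd.contrib.quantize_asym(data)
print(qdata.dtype, scale.asscalar(), shift.asscalar())
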
diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc
index 3835f1a..a4e3086 100644
--- a/src/operator/quantization/quantize_graph_pass.cc
+++ b/src/operator/quantization/quantize_graph_pass.cc
@@ -288,6 +288,10 @@
static const auto& avoid_quantize_input_map =
Op::GetAttr<mxnet::FAvoidQuantizeInput>("FAvoidQuantizeInput");
static const auto& flist_inputs = nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListInputNames");
+ static const auto& avoid_dequantize_map =
+ Op::GetAttr<mxnet::FAvoidDequantizeOutput>("FAvoidDequantizeOutput");
+ static const auto& need_asym_quantize_map =
+ Op::GetAttr<mxnet::FNeedAsymQuantizeInput>("FNeedAsymQuantizeInput");
const auto offline_params = src.GetAttr<std::unordered_set<std::string>>("offline_params");
const auto quantized_dtype = src.GetAttr<std::string>("quantized_dtype");
const auto quantize_granularity = src.GetAttr<std::string>("quantize_granularity");
@@ -331,7 +335,14 @@
if (avoid_quantize_input_map.count(node->op()) &&
avoid_quantize_input_map[node->op()](node->attrs, i, quantize_granularity)) {
new_node->inputs.emplace_back(mirror_entry);
- } else if (!quantized_node_map.count(e.node)) {
+ } else if (!quantized_node_map.count(e.node) ||
+ (avoid_dequantize_map.count(e.node->op()) &&
+ avoid_dequantize_map[e.node->op()](e.node->attrs, e.index))) {
+ // If the input of the current quantized node does not support quantization, a quantize op
+ // is inserted after that input node to convert its float output to int8/uint8. Likewise, a
+ // quantized operator carrying the avoid-dequantize attribute produces float outputs
+ // directly, so a quantize op is still needed to convert them to int8/uint8 before they
+ // feed the current quantized node.
if (mirror_entry_map.count(e)) {
new_node->inputs.emplace_back(mirror_entry_map[e]);
} else {
@@ -354,10 +365,20 @@
new_name = node->attrs.name + "_" + e.node->attrs.name;
}
}
-
- ObjectPtr quantize_node = InsertNode(
- "_contrib_quantize_v2", new_name + suffix + "_quantize", new_node, mirror_entry);
- quantize_node->attrs.dict["out_type"] = quantized_dtype;
+ ObjectPtr quantize_node;
+ if (need_asym_quantize_map.count(node->op()) &&
+ need_asym_quantize_map[node->op()](node->attrs, i)) {
+ // Operators registering FNeedAsymQuantizeInput (currently the quantized RNN, which
+ // requires uint8 input) get the output of the preceding node quantized asymmetrically.
+ quantize_node = InsertNode("_contrib_quantize_asym",
+ new_name + suffix + "_quantize",
+ new_node,
+ mirror_entry);
+ } else {
+ quantize_node = InsertNode(
+ "_contrib_quantize_v2", new_name + suffix + "_quantize", new_node, mirror_entry);
+ quantize_node->attrs.dict["out_type"] = quantized_dtype;
+ }
quantize_node->op()->attr_parser(&(quantize_node->attrs));
mirror_entry_map[e] = NodeEntry{quantize_node, 0, e.version};
}
@@ -439,9 +460,13 @@
for (const auto& e : node->inputs) {
ObjectPtr mirror_node = mirror_map.at(e.node.get());
NodeEntry mirror_entry = NodeEntry{mirror_node, e.index, e.version};
- // if input node is quantized operator, add dequantize node
+ // If the input node is a quantized operator, add a dequantize node. However, if the
+ // input node is a quantized operator with the avoid-dequantize attribute, its output is
+ // already float and no dequantize op is needed.
if (quantized_node_map.count(e.node) &&
- (mirror_node->op() != Op::Get("_contrib_dequantize"))) {
+ mirror_node->op() != Op::Get("_contrib_dequantize") &&
+ !(avoid_dequantize_map.count(e.node->op()) &&
+ avoid_dequantize_map[e.node->op()](e.node->attrs, e.index))) {
// here we calculate the output number (exclude min/max, in order to
// calculate min/max index from mirror node) based on assumption that
// there is only 1 min and 1 max output from mirror node (which is
@@ -473,7 +498,9 @@
std::vector<NodeEntry> outputs;
for (const auto& e : src.outputs) {
- if (quantized_node_map.count(e.node)) {
+ if (quantized_node_map.count(e.node) &&
+ !(avoid_dequantize_map.count(e.node->op()) &&
+ avoid_dequantize_map[e.node->op()](e.node->attrs, e.index))) {
// Only insert dequantize for those Ops supports quantize and not excluded.
ObjectPtr mirror_node = mirror_map.at(e.node.get());
NodeEntry mirror_entry = NodeEntry{mirror_node, e.index, e.version};
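
The intent of the pass changes above can be sanity-checked on an exported quantized network. A rough sketch, reusing the `quant_rnn` network produced by `test_rnn_quantization` further below (the assertions are our expectations for an all-LSTM network, not checks performed by this patch):

import json

qsym, _ = quant_rnn.export(None)
ops = [node['op'] for node in json.loads(qsym.tojson())['nodes']]

# FNeedAsymQuantizeInput: the RNN data input is fed through an asymmetric quantize node.
assert '_contrib_quantize_asym' in ops
# FAvoidDequantizeOutput: the quantized RNN emits float32 directly, so for this
# all-LSTM network no dequantize node is expected after it.
assert '_contrib_quantized_rnn' in ops and '_contrib_dequantize' not in ops
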
diff --git a/src/operator/quantization/quantize_v2.cc b/src/operator/quantization/quantize_v2.cc
index e08bd0d..497ea37 100644
--- a/src/operator/quantization/quantize_v2.cc
+++ b/src/operator/quantization/quantize_v2.cc
@@ -18,7 +18,7 @@
*/
/*!
- * \file quantize.cc
+ * \file quantize_v2.cc
* \brief
*/
diff --git a/src/operator/quantization/quantized_rnn-inl.h b/src/operator/quantization/quantized_rnn-inl.h
new file mode 100644
index 0000000..6ab53ce
--- /dev/null
+++ b/src/operator/quantization/quantized_rnn-inl.h
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file quantized_rnn-inl.h
+ * \brief Common functions for quantized recurrent neural network
+ * \author Zixuan Wei
+ */
+
+#ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RNN_INL_H_
+#define MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RNN_INL_H_
+
+namespace mxnet {
+namespace op {
+
+namespace quantized_rnn {
+enum QuantizedRnnInputs { kData, kParams, kState, kStateCell };
+enum QuantizedRnnInputMinMax { kDataScale, kDataShift };
+enum QuantizedRnnOutputs { kOut, kStateOut, kStateCellOut };
+} // namespace quantized_rnn
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RNN_INL_H_
diff --git a/src/operator/quantization/quantized_rnn.cc b/src/operator/quantization/quantized_rnn.cc
new file mode 100644
index 0000000..88c80bc
--- /dev/null
+++ b/src/operator/quantization/quantized_rnn.cc
@@ -0,0 +1,363 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file quantized_rnn.cc
+ * \brief Common functions for quantized recurrent neural network
+ * \author Zixuan Wei
+ */
+
+#include <dmlc/logging.h>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "operator/rnn-inl.h"
+#include "operator/quantization/quantization_utils.h"
+#include "operator/quantization/quantized_rnn-inl.h"
+
+#if MXNET_USE_ONEDNN == 1
+#include "operator/quantization/dnnl/dnnl_quantized_rnn-inl.h"
+#endif
+
+namespace mxnet {
+namespace op {
+
+uint32_t QuantizedRnnNumInputs(const NodeAttrs& attrs) {
+ const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
+ CHECK_EQ(param.mode, rnn_enum::kLstm)
+ << "Quantized recurrent neural network only supports LSTM operator on "
+ "CPU.";
+ return 6U;
+}
+
+uint32_t QuantizedRnnNumOutputs(const NodeAttrs& attrs) {
+ const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
+ CHECK_EQ(param.mode, rnn_enum::kLstm)
+ << "Quantized recurrent neural network only supports LSTM operator on "
+ "CPU.";
+ return param.state_outputs ? 3U : 1U;
+}
+
+std::vector<std::string> QuantizedRnnInputNames(const NodeAttrs& attrs) {
+ const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
+ CHECK_EQ(param.mode, rnn_enum::kLstm)
+ << "Quantized recurrent neural network only supports LSTM operator on "
+ "CPU.";
+ return std::vector<std::string>{
+ "data", "parameters", "state", "state_cell", "min_data", "max_data"};
+}
+
+std::vector<std::string> QuantizedRnnOutputNames(const NodeAttrs& attrs) {
+ const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
+ CHECK_EQ(param.mode, rnn_enum::kLstm)
+ << "Quantized recurrent neural network only supports LSTM operator on "
+ "CPU.";
+ if (param.state_outputs) {
+ return std::vector<std::string>{"output", "state_output", "statecell_ouput"};
+ } else {
+ return std::vector<std::string>{"output"};
+ }
+}
+
+bool QuantizedRnnShape(const nnvm::NodeAttrs& attrs,
+ std::vector<TShape>* in_shape,
+ std::vector<TShape>* out_shape) {
+ const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
+ CHECK_EQ(param.mode, rnn_enum::kLstm) << "Quantized RNN operator only supports LSTM mode.";
+
+ const uint32_t num_inputs = QuantizedRnnNumInputs(attrs);
+ const uint32_t num_outputs = QuantizedRnnNumOutputs(attrs);
+ CHECK_EQ(in_shape->size(), num_inputs)
+ << "Arguments' size of quantized RNN operator is mismatched. Expected " << num_inputs
+ << " argmuments but got " << in_shape->size() << ".";
+ CHECK_EQ(out_shape->size(), num_outputs);
+
+ const mxnet::TShape dshape = in_shape->at(quantized_rnn::kData);
+ if (!mxnet::ndim_is_known(dshape))
+ return false;
+ CHECK_EQ(dshape.ndim(), 3U) << "Input data of RNN operator should be 3-rank "
+ "tensor of dim [steps, batch, input size]";
+ const dim_t batch_size = dshape[1];
+ const dim_t input_size = dshape[2];
+ const dim_t directions = param.bidirectional ? 2 : 1;
+ const dim_t total_lyrs = directions * param.num_layers;
+ const dim_t state_size = param.state_size;
+ SHAPE_ASSIGN_CHECK(*in_shape, quantized_rnn::kState, Shape3(total_lyrs, batch_size, state_size));
+ if (param.mode == rnn_enum::kLstm)
+ SHAPE_ASSIGN_CHECK(
+ *in_shape, quantized_rnn::kStateCell, Shape3(total_lyrs, batch_size, state_size));
+
+ const int param_size_fp = GetRnnParamSize(
+ param.num_layers, input_size, state_size, directions, param.mode, param.projection_size);
+ SHAPE_ASSIGN_CHECK(*in_shape, quantized_rnn::kParams, Shape1(param_size_fp));
+ const uint32_t num_base_inputs = GetRnnNumInputs(param);
+ for (size_t i = num_base_inputs; i < num_inputs; ++i)
+ SHAPE_ASSIGN_CHECK(*in_shape, i, Shape1(1));
+
+ out_shape->clear();
+ out_shape->push_back({dshape[0], batch_size, directions * state_size}); // output dim: [T, N, C]
+ if (param.state_outputs) {
+ out_shape->push_back({total_lyrs, batch_size, state_size}); // state dim: [L*D, N, C]
+ if (param.mode == rnn_enum::kLstm)
+ out_shape->push_back({total_lyrs, batch_size, state_size}); // cell dim: [L*D, N, C]
+ }
+ return true;
+}
+
+bool QuantizedRnnType(const nnvm::NodeAttrs& attrs,
+ std::vector<int>* in_type,
+ std::vector<int>* out_type) {
+ const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
+ CHECK_EQ(param.mode, rnn_enum::kLstm) << "Quantized RNN operator only supports LSTM mode.";
+
+ const uint32_t num_inputs = QuantizedRnnNumInputs(attrs);
+ const uint32_t num_outputs = QuantizedRnnNumOutputs(attrs);
+ CHECK_EQ(in_type->size(), num_inputs);
+ CHECK_EQ(out_type->size(), num_outputs);
+
+ CHECK_EQ(in_type->at(quantized_rnn::kData), mshadow::kUint8)
+ << "Quantized RNN operator only supports uint8 input, while "
+ << in_type->at(quantized_rnn::kData) << " is given.";
+ TYPE_ASSIGN_CHECK(*in_type, quantized_rnn::kParams, mshadow::kFloat32);
+ TYPE_ASSIGN_CHECK(*in_type, quantized_rnn::kState, mshadow::kFloat32);
+ const uint32_t num_base_inputs = GetRnnNumInputs(param);
+ if (param.mode == rnn_enum::kLstm)
+ TYPE_ASSIGN_CHECK(*in_type, quantized_rnn::kStateCell, mshadow::kFloat32);
+ for (size_t i = num_base_inputs; i < num_inputs; ++i)
+ TYPE_ASSIGN_CHECK(*in_type, i, mshadow::kFloat32);
+
+ TYPE_ASSIGN_CHECK(*out_type, quantized_rnn::kOut, mshadow::kFloat32);
+ if (param.state_outputs) {
+ TYPE_ASSIGN_CHECK(*out_type, quantized_rnn::kStateOut, mshadow::kFloat32);
+ if (param.mode == rnn_enum::kLstm)
+ TYPE_ASSIGN_CHECK(*out_type, quantized_rnn::kStateCellOut, mshadow::kFloat32);
+ }
+ return true;
+}
+
+bool QuantizedRnnStorageType(const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int>* in_attrs,
+ std::vector<int>* out_attrs) {
+ const uint32_t num_inputs = QuantizedRnnNumInputs(attrs);
+ const uint32_t num_outputs = QuantizedRnnNumOutputs(attrs);
+ CHECK_EQ(in_attrs->size(), num_inputs);
+ CHECK_EQ(out_attrs->size(), num_outputs);
+
+#if MXNET_USE_ONEDNN == 1
+ return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
+#else
+ *dispatch_mode = DispatchMode::kFCompute;
+
+ for (auto& v : *out_attrs) {
+ v = kDefaultStorage;
+ if (common::stype_string(v).compare("unknown") == 0) {
+ return false;
+ }
+ }
+
+ for (auto& v : *in_attrs) {
+ v = kDefaultStorage;
+ if (common::stype_string(v).compare("unknown") == 0) {
+ return false;
+ }
+ }
+ return true;
+#endif
+}
+
+void QuantizedRnnParamParser(nnvm::NodeAttrs* attrs) {
+ RNNParam param;
+ attrs->dict["quantized"] = "true";
+ try {
+ param.Init(attrs->dict, dmlc::parameter::kAllowUnknown);
+ } catch (const dmlc::ParamError& e) {
+ std::ostringstream os;
+ os << e.what();
+ os << ", in operator " << attrs->op->name << "("
+ << "name=\"" << attrs->name << "\"";
+ for (const auto& k : attrs->dict) {
+ os << ", " << k.first << "=\"" << k.second << "\"";
+ }
+ os << ")";
+ throw dmlc::ParamError(os.str());
+ }
+ attrs->parsed = std::move(param);
+}
+
+OpStatePtr CreateQuantizedRnnState(const nnvm::NodeAttrs& attrs,
+ const Context ctx,
+ const mxnet::ShapeVector& in_shapes,
+ const std::vector<int>& in_types) {
+ const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
+ CHECK_EQ(param.mode, rnn_enum::kLstm) << "Quantized RNN operator only supports LSTM mode.";
+ OpStatePtr state = OpStatePtr();
+#if MXNET_USE_ONEDNN == 1
+ const int data_type = in_types[quantized_rnn::kData];
+ const int weight_type = in_types[quantized_rnn::kParams];
+ if (data_type == mshadow::kUint8 && weight_type == mshadow::kFloat32) {
+ const mxnet::TShape& data_shape = in_shapes[quantized_rnn::kData];
+ state =
+ OpStatePtr::Create<DNNLQuantizedRnnOp>(attrs, data_shape[0], data_shape[1], data_shape[2]);
+ }
+#else
+ LOG(FATAL) << "Quantized RNN operator relies on oneDNN library."
+ << " Please build MXNet with USE_ONEDNN=ON to leverage this operator.";
+#endif
+ return state;
+}
+
+void QuantizedRnnForwardCPU(const OpStatePtr& state_ptr,
+ const OpContext& ctx,
+ const std::vector<TBlob>& in_data,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& out_data) {
+ LOG(FATAL) << "Quantized RNN operator relies on oneDNN library."
+ << " Please build MXNet with USE_ONEDNN=ON to leverage this operator.";
+}
+
+#if MXNET_USE_ONEDNN == 1
+void QuantizedRnnForwardCPUEx(const OpStatePtr& state_ptr,
+ const OpContext& ctx,
+ const std::vector<NDArray>& in_data,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& out_data) {
+ DNNLQuantizedRnnOp& op = state_ptr.get_state<DNNLQuantizedRnnOp>();
+ op.Forward(ctx, in_data, req, out_data);
+}
+#endif // MXNET_USE_ONEDNN == 1
+
+bool NeedAsymQuantizeRnnInput(const NodeAttrs& attrs, const size_t index_to_check) {
+ bool need_asym_quantize = false;
+ switch (index_to_check) {
+ case rnn_enum::kData: {
+ need_asym_quantize = true;
+ break;
+ }
+ default: {
+ need_asym_quantize = false;
+ }
+ }
+ return need_asym_quantize;
+}
+
+bool AvoidRnnQuantizeInput(const NodeAttrs& attrs,
+ const size_t index_to_check,
+ const std::string quantize_granularity) {
+ std::unordered_set<size_t> avoid_indexes;
+ avoid_indexes.insert({quantized_rnn::kParams, quantized_rnn::kState, quantized_rnn::kStateCell});
+
+ return avoid_indexes.count(index_to_check);
+}
+
+bool AvoidRnnDequantizeOutput(const NodeAttrs& attrs, const size_t index_to_check) {
+ return true;
+}
+
+static std::vector<ResourceRequest> QuantizedRnnResourceEx(const NodeAttrs& attrs,
+ const int dev_mask,
+ const DispatchMode dispatch_mode) {
+ std::vector<ResourceRequest> request;
+ if (dev_mask == kGPU) {
+#if MXNET_USE_CUDNN == 1
+ LOG(FATAL) << "Currently, quantized RNN is not supported on the GPU platform.";
+#endif
+ } else {
+#if MXNET_USE_ONEDNN == 1
+ request.emplace_back(ResourceRequest::kTempSpace);
+#endif
+ }
+ return request;
+}
+
+NNVM_REGISTER_OP(_contrib_quantized_rnn)
+ .add_alias("_npx_contrib_quantized_rnn")
+ .describe(R"code(RNN operator for input data type of uint8. The weight of each
+gates is converted to int8, while bias is accumulated in type float32.
+The hidden state and cell state are in type float32. For the input data, two more arguments
+of type float32 must be provided representing the thresholds of quantizing argument from
+data type float32 to uint8. The final outputs contain the recurrent result in float32.
+It only supports quantization for Vanilla LSTM network.
+
+.. Note::
+ This operator only supports forward propagation. DO NOT use it in training.)code" ADD_FILELINE)
+ .set_num_inputs(QuantizedRnnNumInputs)
+ .set_num_outputs(QuantizedRnnNumOutputs)
+ .set_attr_parser(QuantizedRnnParamParser)
+ .set_attr<nnvm::FListInputNames>("FListInputNames", QuantizedRnnInputNames)
+ .set_attr<nnvm::FListOutputNames>("FListOutputNames", QuantizedRnnOutputNames)
+ .set_attr<mxnet::FInferShape>("FInferShape", QuantizedRnnShape)
+ .set_attr<nnvm::FInferType>("FInferType", QuantizedRnnType)
+ .set_attr<FInferStorageType>("FInferStorageType", QuantizedRnnStorageType)
+ .set_attr<FCreateOpState>("FCreateOpState", CreateQuantizedRnnState)
+ .set_attr<FStatefulCompute>("FStatefulCompute<cpu>", QuantizedRnnForwardCPU)
+#if MXNET_USE_ONEDNN == 1
+ .set_attr<bool>("TIsDNNL", true)
+ .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", QuantizedRnnForwardCPUEx)
+#endif
+ .set_attr<FResourceRequestEx>("FResourceRequestEx", QuantizedRnnResourceEx)
+ .add_argument("data", "NDArray-or-Symbol", "Input data.")
+ .add_argument("parameters", "NDArray-or-Symbol", "weight.")
+ .add_argument("state", "NDArray-or-Symbol", "initial hidden state of the RNN")
+ .add_argument("state_cell",
+ "NDArray-or-Symbol",
+ "initial cell state for LSTM networks (only for LSTM)")
+ .add_argument("data_scale", "NDArray-or-Symbol", "quantization scale of data.")
+ .add_argument("data_shift", "NDArray-or-Symbol", "quantization shift of data.")
+ .add_arguments(RNNParam::__FIELDS__());
+
+NNVM_REGISTER_OP(RNN)
+ .set_attr<FQuantizable>("FQuantizable",
+ [](const NodeAttrs& attrs) {
+#if MXNET_USE_ONEDNN == 1
+ const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
+ if (param.mode != rnn_enum::kLstm)
+ LOG(INFO) << "Quantized RNN only supports LSTM mode.";
+ if (param.mode == rnn_enum::kLstm &&
+ !param.projection_size.has_value()) {
+ return QuantizeType::kMust;
+ } else {
+ return QuantizeType::kNone;
+ }
+#else
+ LOG(INFO) << "Quantized RNN is not supported by this MXNet release. Please enable oneDNN to "
+ << "use the feature.";
+ return QuantizeType::kNone;
+#endif // MXNET_USE_ONEDNN == 1
+ })
+ .set_attr<FQuantizedOp>("FQuantizedOp",
+ [](const NodeAttrs& attrs) {
+ nnvm::ObjectPtr node = nnvm::Node::Create();
+ node->attrs.op = Op::Get("_contrib_quantized_rnn");
+ node->attrs.name = "quantized_" + attrs.name;
+ node->attrs.dict = attrs.dict;
+ node->attrs.dict["quantized"] = "true";
+ if (node->op()->attr_parser != nullptr) {
+ node->op()->attr_parser(&(node->attrs));
+ }
+ return node;
+ })
+ .set_attr<FNeedAsymQuantizeInput>("FNeedAsymQuantizeInput", NeedAsymQuantizeRnnInput)
+ .set_attr<FAvoidQuantizeInput>("FAvoidQuantizeInput", AvoidRnnQuantizeInput)
+ .set_attr<FAvoidDequantizeOutput>("FAvoidDequantizeOutput", AvoidRnnDequantizeOutput);
+
+} // namespace op
+} // namespace mxnet
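
The FQuantizable hook above marks only vanilla LSTM layers (without projection) as `kMust`; projected LSTM (LSTMP) and other RNN modes are left in float32. A small sketch of which Gluon layers are affected (the layer configurations are illustrative only; the actual selection happens inside `quantize_net`):

import mxnet as mx

# Quantizable: plain LSTM without projection -> replaced by _contrib_quantized_rnn.
lstm = mx.gluon.rnn.LSTM(hidden_size=16, num_layers=1)

# Not quantizable: projection_size is set (LSTMP), so FQuantizable returns kNone.
lstmp = mx.gluon.rnn.LSTM(hidden_size=16, num_layers=1, projection_size=8)

# Not quantizable: GRU is not an LSTM mode, so it also stays in float32.
gru = mx.gluon.rnn.GRU(hidden_size=16, num_layers=1)
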
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index c348554..eac274f 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -291,9 +291,9 @@
return size;
}
-inline size_t GetNumInputArguments(RNNParam param_) {
- size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4U : 3U;
- if (param_.use_sequence_length)
+inline size_t GetRnnNumInputs(RNNParam param) {
+ size_t num_inputs = (param.mode == rnn_enum::kLstm) ? 4U : 3U;
+ if (param.use_sequence_length)
num_inputs += 1U;
return num_inputs;
}
@@ -748,7 +748,7 @@
using namespace mshadow::expr;
CHECK(param_.p >= 0.0f && param_.p < 1.0f)
<< "unsupported dropout value, should be 0 <= dropout < 1";
- size_t num_inputs = GetNumInputArguments(param_);
+ size_t num_inputs = GetRnnNumInputs(param_);
// kOut
size_t num_outputs = 1;
@@ -1125,7 +1125,7 @@
CHECK(param_.p >= 0.0f && param_.p < 1.0f)
<< "unsupported dropout value, should be 0 <= dropout < 1";
- size_t num_inputs = GetNumInputArguments(param_);
+ size_t num_inputs = GetRnnNumInputs(param_);
// kOut
size_t num_outputs = 1;
@@ -1369,7 +1369,7 @@
const std::vector<TBlob>& out_data) {
using namespace mshadow;
- size_t num_inputs = GetNumInputArguments(param_);
+ size_t num_inputs = GetRnnNumInputs(param_);
// kOut
size_t num_outputs = 1;
if (param_.state_outputs) {
diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index e4b84dd..5a03b06 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
@@ -34,31 +34,41 @@
namespace op {
DMLC_REGISTER_PARAMETER(RNNParam);
-static inline std::vector<std::string> ListArguments(const RNNParam& param_) {
+static inline std::vector<std::string> ListRnnInputNames(const RNNParam& param) {
// All RNNs start off with same 3 input arguments
std::vector<std::string> arguments{"data", "parameters", "state"};
// LSTMs also have an additional state_cell argument
- if (param_.mode == rnn_enum::kLstm) {
+ if (param.mode == rnn_enum::kLstm) {
arguments.emplace_back("state_cell");
}
// All RNNs have option of additional sequence_length argument
- if (param_.use_sequence_length) {
+ if (param.use_sequence_length) {
arguments.emplace_back("sequence_length");
}
return arguments;
}
+static inline std::vector<std::string> ListRnnOutputNames(const RNNParam& param) {
+ std::vector<std::string> names{"output"};
+ if (param.state_outputs) {
+ names.emplace_back("state_output");
+ if (param.mode == rnn_enum::kLstm)
+ names.emplace_back("statecell_output");
+ }
+ return names;
+}
+
static bool RNNShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) {
- const RNNParam& param_ = nnvm::get<RNNParam>(attrs.parsed);
using namespace mshadow;
+ const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
- // Query param_ object to figure out what the expectd input arguments are
- std::vector<std::string> expected_arguments = ListArguments(param_);
+ // Query the param object to figure out what the expected input arguments are
+ std::vector<std::string> expected_arguments = ListRnnInputNames(param);
CHECK_EQ(in_shape->size(), expected_arguments.size())
<< "Input shape mismatch. Expected " << expected_arguments.size()
@@ -76,29 +86,29 @@
}
int batch_size = dshape[1];
int input_size = dshape[2];
- int numDirections = param_.bidirectional ? 2 : 1;
- int total_layers = numDirections * param_.num_layers; // double for bidirectional
+ int numDirections = param.bidirectional ? 2 : 1;
+ int total_layers = numDirections * param.num_layers; // double for bidirectional
int layer_size =
- (param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size;
+ (param.projection_size.has_value()) ? param.projection_size.value() : param.state_size;
SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kState, Shape3(total_layers, batch_size, layer_size));
- if (param_.mode == rnn_enum::kLstm) {
+ if (param.mode == rnn_enum::kLstm) {
SHAPE_ASSIGN_CHECK(
- *in_shape, rnn_enum::kStateCell, Shape3(total_layers, batch_size, param_.state_size));
+ *in_shape, rnn_enum::kStateCell, Shape3(total_layers, batch_size, param.state_size));
}
// calculate parameter vector length
- int param_size = GetRnnParamSize(param_.num_layers,
+ int param_size = GetRnnParamSize(param.num_layers,
input_size,
- param_.state_size,
+ param.state_size,
numDirections,
- param_.mode,
- param_.projection_size);
+ param.mode,
+ param.projection_size);
SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size));
// Check on sequence_length shape if using
- if (param_.use_sequence_length) {
+ if (param.use_sequence_length) {
size_t seq_len_input_idx = rnn_enum::kSequenceLength;
- if (param_.mode != rnn_enum::kLstm)
+ if (param.mode != rnn_enum::kLstm)
--seq_len_input_idx;
SHAPE_ASSIGN_CHECK(*in_shape, seq_len_input_idx, Shape1(batch_size));
@@ -107,29 +117,29 @@
out_shape->clear();
// output: [sequence len, batch, output size]
TShape oshape = dshape;
- if (param_.projection_size.has_value()) {
- oshape[2] = numDirections * param_.projection_size.value();
+ if (param.projection_size.has_value()) {
+ oshape[2] = numDirections * param.projection_size.value();
} else {
- oshape[2] = numDirections * param_.state_size;
+ oshape[2] = numDirections * param.state_size;
}
out_shape->push_back(oshape);
- if (param_.state_outputs) {
+ if (param.state_outputs) {
// outStateShape: [layer_num, batch, state size]
TShape outStateShape = dshape;
outStateShape[0] = total_layers;
outStateShape[1] = batch_size;
- if (param_.projection_size.has_value()) {
- outStateShape[2] = param_.projection_size.value();
+ if (param.projection_size.has_value()) {
+ outStateShape[2] = param.projection_size.value();
} else {
- outStateShape[2] = param_.state_size;
+ outStateShape[2] = param.state_size;
}
out_shape->push_back(outStateShape);
// Deal with lstm cell state
- if (param_.mode == rnn_enum::kLstm) {
+ if (param.mode == rnn_enum::kLstm) {
TShape cellStateShape = dshape;
cellStateShape[0] = total_layers;
cellStateShape[1] = batch_size;
- cellStateShape[2] = param_.state_size;
+ cellStateShape[2] = param.state_size;
out_shape->push_back(cellStateShape);
}
}
@@ -140,34 +150,34 @@
static bool RNNType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_type,
std::vector<int>* out_type) {
- const RNNParam& param_ = nnvm::get<RNNParam>(attrs.parsed);
+ const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
- CHECK_EQ(in_type->size(), GetNumInputArguments(param_));
+ CHECK_EQ(in_type->size(), GetRnnNumInputs(param));
size_t seq_len_input_idx = rnn_enum::kSequenceLength;
- if (param_.mode != rnn_enum::kLstm)
+ if (param.mode != rnn_enum::kLstm)
--seq_len_input_idx;
int dtype = (*in_type)[0];
CHECK_NE(dtype, -1) << "First input must have specified type";
- std::vector<std::string> arguments = ListArguments(param_);
+ std::vector<std::string> arguments = ListRnnInputNames(param);
for (size_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
TYPE_ASSIGN_CHECK(*in_type, i, dtype);
} else {
// If using sequence length argument, it has its own indexing type
// All other input arguments must match the main data type
- if (!(param_.use_sequence_length && i == seq_len_input_idx)) {
+ if (!(param.use_sequence_length && i == seq_len_input_idx)) {
UNIFORM_TYPE_CHECK((*in_type)[i], dtype, arguments[i]);
}
}
}
out_type->clear();
out_type->push_back(dtype);
- if (param_.state_outputs) {
+ if (param.state_outputs) {
out_type->push_back(dtype);
// Deal with lstm cell state
- if (param_.mode == rnn_enum::kLstm) {
+ if (param.mode == rnn_enum::kLstm) {
out_type->push_back(dtype);
}
}
@@ -248,7 +258,7 @@
#if MXNET_USE_ONEDNN == 1
if (ctx.dev_type == kCPU && SupportDNNLRnn(param, in_types[rnn_enum::kData])) {
const mxnet::TShape& data_shape = in_shapes[rnn_enum::kData];
- state = OpStatePtr::Create<DNNLRnnOp>(param, data_shape[0], data_shape[1], data_shape[2]);
+ state = OpStatePtr::Create<DNNLRnnOp>(attrs, data_shape[0], data_shape[1], data_shape[2]);
return state;
}
#endif // MXNET_USE_ONEDNN == 1
@@ -370,7 +380,7 @@
.set_attr_parser(ParamParser<RNNParam>)
.set_num_inputs([](const NodeAttrs& attrs) {
const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
- return GetNumInputArguments(params);
+ return GetRnnNumInputs(params);
})
.set_num_outputs([](const NodeAttrs& attrs) {
const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
@@ -386,18 +396,12 @@
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
- return ListArguments(params);
+ return ListRnnInputNames(params);
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
[](const NodeAttrs& attrs) {
const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
- std::vector<std::string> names{"output"};
- if (params.state_outputs) {
- names.emplace_back("state_output");
- if (params.mode == rnn_enum::kLstm)
- names.emplace_back("statecell_output");
- }
- return names;
+ return ListRnnOutputNames(params);
})
.set_attr<mxnet::FInferShape>("FInferShape", RNNShape)
.set_attr<nnvm::FInferType>("FInferType", RNNType)
@@ -441,7 +445,7 @@
})
.set_num_outputs([](const NodeAttrs& attrs) {
const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
- return GetNumInputArguments(params);
+ return GetRnnNumInputs(params);
})
.set_attr_parser(ParamParser<RNNParam>)
.set_attr<bool>("TIsLayerOpBackward", true)
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index dcd4bbd..6b74a49 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -1414,3 +1414,106 @@
assert 'layer1' in min_max_dict
assert_almost_equal(onp.array([min_max_dict['layer1'][1]]), expected_threshold, rtol=1e-2, atol=1e-4)
+
+@use_np
+def test_rnn_quantization():
+ data_low = -1
+ data_high = 1
+ def check_rnn_quantization(num_layers, bidirectional, seq_len, batch_size, input_dim, state_size):
+ data_shape = (seq_len, batch_size, input_dim)
+
+ rnn_fp32 = mx.gluon.rnn.LSTM(hidden_size=state_size,
+ num_layers=num_layers,
+ bidirectional=bidirectional)
+
+ data = mx.np.random.uniform(low=data_low, high=data_high, size=data_shape)
+ states_shape = (num_layers * 2 if bidirectional else num_layers, batch_size, state_size)
+ states = [mx.np.zeros(states_shape) for _ in range(2)]  # LSTM expects [h, c]
+
+ rnn_fp32.initialize()
+ rnn_fp32.hybridize()
+ ref_out = rnn_fp32(data, states)
+
+ class RNNDataLoader(mx.gluon.data.DataLoader):
+ def __init__(self, data, states):
+ super().__init__(mx.gluon.data.SimpleDataset([]), 1)
+ self.data = data
+ self.states = states
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ return [self.data, self.states]
+
+ def __bool__(self):
+ return bool(self.dataiter.iter_next())
+
+ calib_data = RNNDataLoader(data, states)
+ quant_rnn = mx.contrib.quant.quantize_net(rnn_fp32,
+ quantized_dtype='auto',
+ quantize_mode='full',
+ calib_data=calib_data,
+ calib_mode='naive',
+ num_calib_batches=1,
+ device=mx.current_device())
+ qout = quant_rnn(data, states)
+
+ qsym, _ = quant_rnn.export(None)
+ assert qsym.tojson().find("quantized_rnn") != -1
+
+ ref_out = [ref_out[0], ref_out[1][0], ref_out[1][1]]
+ for i in range(len(qout)):
+ mse = onp.mean((ref_out[i].asnumpy() - qout[i].asnumpy())**2)
+ assert mse < 0.001
+
+ check_rnn_quantization(1, False, 5, 2, 16, 16)
+ check_rnn_quantization(1, True, 5, 2, 16, 16)
+
+
+
+@use_np
+def test_quantized_rnn():
+ def check_quantized_rnn(num_layers, bidirectional, seq_len, batch_size, input_dim, state_size):
+ ndir = 2 if bidirectional else 1
+ size = ndir*state_size*4
+ first_lyr_param_size = (input_dim + state_size + 2) * size
+ other_lyr_param_size = (state_size * ndir + state_size + 2) * size
+ full_param_size = first_lyr_param_size + (num_layers - 1) * other_lyr_param_size
+
+ data = mx.np.random.uniform(-1, 1, (seq_len, batch_size, input_dim))
+ state = mx.np.random.uniform(-1, 1, (num_layers*ndir, batch_size, state_size))
+ state_cell = mx.np.random.uniform(0, 1, (num_layers*ndir, batch_size, state_size))
+ params = mx.np.random.normal(0, 1, (full_param_size,))
+
+ out = npx.rnn(data=data,
+ parameters=params,
+ mode='lstm',
+ state=state,
+ state_size=state_size,
+ state_cell=state_cell,
+ num_layers=num_layers,
+ bidirectional=bidirectional)
+
+ data_min = mx.np.min(data)
+ data_max = mx.np.max(data)
+ data_scale = mx.np.array(128.0 / (data_max - data_min)).reshape((1,))
+ data_shift = mx.np.array(128.0 - data_max * data_scale).reshape((1,))
+
+ qdata = (data * data_scale + data_shift + 0.5).astype('uint8')
+ qout = npx.contrib_quantized_rnn(data=qdata,
+ parameters=params,
+ mode='lstm',
+ state=state,
+ state_size=state_size,
+ state_cell=state_cell,
+ num_layers=num_layers,
+ bidirectional=bidirectional,
+ data_scale=data_scale,
+ data_shift=data_shift)
+
+ mse = onp.mean((out.asnumpy() - qout.asnumpy())**2)
+ assert mse < 0.001
+
+ check_quantized_rnn(1, False, 5, 2, 16, 16)
+ check_quantized_rnn(1, True, 5, 2, 16, 16)
\ No newline at end of file