Reduce after quantization memory usage (#20894)
diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py
index 4ad354a..10d2455 100644
--- a/python/mxnet/contrib/quantization.py
+++ b/python/mxnet/contrib/quantization.py
@@ -921,6 +921,9 @@
if calib_mode in ['naive', 'entropy', 'custom']:
inputs = [mx.sym.var(desc.name) for desc in data_descs]
calib_net = SymbolBlock(symnet, inputs)
+ for k, v in calib_net.collect_params().items():
+ v.grad_req = 'null'
+
calib_net.load_dict(params, cast_dtype=True, dtype_source='saved')
calib_net.hybridize(static_alloc=False, static_shape=False)
num_batches = _collect_layer_statistics(calib_net, calib_data, collector, num_inputs,
@@ -939,6 +942,9 @@
inputs = [mx.sym.var(desc.name) for desc in data_descs]
net = SymbolBlock(qsym, inputs)
+ for k, v in net.collect_params().items():
+ v.grad_req = 'null'
+
all_params = {('arg:%s' % k): v.as_in_context(cpu()) for k, v in qarg_params.items()}
all_params.update({('aux:%s' % k): v.as_in_context(cpu()) for k, v in aux_params.items()})
net.load_dict(all_params, cast_dtype=True, dtype_source='saved')