| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
import sys
import os
import itertools
import mxnet as mx
from mxnet.io import NDArrayIter
from mxnet.module import Module
from mxnet.test_utils import DummyIter, assert_almost_equal
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, '../unittest/'))
from common import with_seed
| |
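# Expected MKLDNN fused op names, their quantized counterparts, and the subgraph
# passes that produce them, keyed by the pattern under test ('conv' or 'fc').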
| OP_NAME='op_name' |
| QUANTIZED_OP_NAME='quantized_op_name' |
| SG_PASS_NAME='sg_pass_name' |
| POST_SG_PASS_NAME='post_sg_pass_name' |
| config = { |
| 'conv': { |
| OP_NAME: 'sg_mkldnn_conv', |
| QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_conv', |
| SG_PASS_NAME: 'MKLDNN', |
| POST_SG_PASS_NAME: 'MKLDNN_POST_QUANTIZE' |
| }, |
| 'fc': { |
| OP_NAME: 'sg_mkldnn_fully_connected', |
| QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_fully_connected', |
| SG_PASS_NAME: 'MKLDNN', |
| POST_SG_PASS_NAME: 'MKLDNN_POST_QUANTIZE' |
| } |
| } |
| |
| DATA_SHAPE=[(4, 4, 10, 10), (32, 3, 24, 24), (64, 8, 64, 64)] |
| |
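# Check that a calibrated quantized symbol contains the expected quantized op, that
# every quantize node uses the requested out_type, and that each quantized subgraph
# node (except FC with float output enabled) carries min/max calibration ranges.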
| def check_qsym_calibrated(qsym, out_type, name='conv'): |
| quantized_op_name = config[name][QUANTIZED_OP_NAME] |
| assert ''.join(qsym.attr_dict().keys()).find(quantized_op_name) != -1 |
| for k, v in qsym.attr_dict().items(): |
| if k.find('_quantize') != -1: |
| assert v['out_type'] == out_type |
| if k.find(quantized_op_name) != -1: |
| if name == 'fc' and 'enable_float_output' in v: |
| continue |
| assert 'min_calib_range' in v |
| assert 'max_calib_range' in v |
| |
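# Check that every quantized conv subgraph node carries calibration ranges and that
# all of them share the same min/max values, i.e. their output scales are aligned.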
| def check_qsym_scale_align(qsym): |
| assert ''.join(qsym.attr_dict().keys()).find('quantized_sg_mkldnn_conv') != -1 |
| init = False |
| for k, v in qsym.attr_dict().items(): |
| if k.find('quantized_sg_mkldnn_conv') != -1: |
| assert 'min_calib_range' in v |
| assert 'max_calib_range' in v |
| if not init: |
| min_calib_range = v['min_calib_range'] |
| max_calib_range = v['max_calib_range'] |
| init = True |
| else: |
| assert min_calib_range == v['min_calib_range'] |
| assert max_calib_range == v['max_calib_range'] |
| |
| |
| |
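# Run the quantized symbol with the calibrated params through a Module and return its outputs.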
| def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape): |
| mod = Module(symbol=qsym, context=mx.current_context()) |
| mod.bind(for_training=False, |
| data_shapes=[('data', data_shape)], |
| label_shapes=[('softmax_label', label_shape)]) |
| mod.set_params(qarg_params, qaux_params) |
| mod.forward(batch, is_train=False) |
| for output in mod.get_outputs(): |
| output.wait_to_read() |
| return mod.get_outputs() |
| |
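# Forward the quantized symbol with freshly initialized params; only verifies that inference runs.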
| def check_qsym_dummy_forward(qsym, batch, data_shape, label_shape): |
| mod = Module(symbol=qsym, context=mx.current_context()) |
| mod.bind(for_training=False, |
| data_shapes=[('data', data_shape)], |
| label_shapes=[('softmax_label', label_shape)]) |
| mod.init_params(initializer=mx.init.Xavier(magnitude=2.)) |
| mod.forward(batch, is_train=False) |
| for output in mod.get_outputs(): |
| output.wait_to_read() |
| return mod.get_outputs() |
| |
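# Save the quantized symbol and params to disk, reload them through gluon.SymbolBlock,
# and run a forward pass to make sure the quantized model is usable from Gluon.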
| def check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape): |
| # save qsym to JSON file |
| qsym.save('quantized-symbol.json') |
| # save params |
| save_dict = {('arg:%s' % k): v.as_in_context(mx.current_context()) for k, v in qarg_params.items()} |
| save_dict.update({('aux:%s' % k): v.as_in_context(mx.current_context()) for k, v in qaux_params.items()}) |
| mx.nd.save('quantized-0000.params', save_dict) |
| # load back with SymbolBlock |
| net = mx.gluon.SymbolBlock.imports('quantized-symbol.json', ['data'], 'quantized-0000.params') |
    net.collect_params().reset_ctx(ctx=mx.current_context())
| net.hybridize() |
| |
| data = mx.random.uniform(-1.0, 1.0, shape=data_shape) |
| net(data) |
| |
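# End-to-end quantization check: append an FC head (plus SoftmaxOutput for the Module
# path), run the MKLDNN subgraph pass, record an FP32 reference output, quantize with
# naive calibration, apply the post-quantize pass, then validate the calibration
# attributes and compare the quantized outputs against the reference (atol=1).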
| def check_quantize(sym, data_shape, out_type, name='conv', |
| check_calibration=True, gluon_forward=False, check_scale_align=False): |
| sg_pass_name = config[name][SG_PASS_NAME] |
| post_sg_pass_name = config[name][POST_SG_PASS_NAME] |
| |
| fc = mx.sym.FullyConnected(data=sym, num_hidden=10, flatten=True, name='fc_softmax') |
    if gluon_forward:
| sym = fc |
| sym_sg = sym.get_backend_symbol(sg_pass_name) |
| mod = Module(symbol=sym, label_names=[]) |
| mod.bind(for_training=False, |
| data_shapes=[('data', data_shape)]) |
| else: |
| sym = mx.sym.SoftmaxOutput(data=fc, name='softmax') |
| sym_sg = sym.get_backend_symbol(sg_pass_name) |
| label_shape = (data_shape[0], 10) |
| mod = Module(symbol=sym) |
| mod.bind(for_training=False, |
| data_shapes=[('data', data_shape)], |
| label_shapes=[('softmax_label', label_shape)]) |
| mod.init_params(mx.init.Normal(0.5)) |
| arg_params, aux_params = mod.get_params() |
| |
| data = [mx.random.uniform(-1, 1, shape=shape, ctx=mx.current_context()) for _, shape in mod.data_shapes] |
| batch = mx.io.DataBatch(data, []) |
| |
| mod.forward(batch, is_train=False) |
| for output in mod.get_outputs(): |
| output.wait_to_read() |
| ref_out = mod.get_outputs() |
| |
| excluded_sym_names = [] |
    if mx.current_context() == mx.cpu() and gluon_forward:
| excluded_sym_names += ['sg_mkldnn_fully_connected_0'] |
| excluded_sym_names += ['fc_softmax'] |
| |
| calib_data = mx.nd.random.uniform(shape=data_shape) |
| calib_data = NDArrayIter(data=calib_data) |
| calib_data = DummyIter(calib_data) |
| calib_layer = lambda name: name.endswith('_output') |
| qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg, |
| arg_params=arg_params, |
| aux_params=aux_params, |
| ctx=mx.current_context(), |
| excluded_sym_names=excluded_sym_names, |
| quantized_dtype=out_type, |
| calib_mode='naive', |
| calib_data=calib_data, |
| calib_layer=calib_layer, |
| num_calib_examples=5) |
| qsym = qsym.get_backend_symbol(post_sg_pass_name) |
| if check_calibration: |
| check_qsym_calibrated(qsym, out_type, name=name) |
| if check_scale_align: |
| check_qsym_scale_align(qsym) |
    if gluon_forward:
| check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape) |
| else: |
| check_qsym_dummy_forward(qsym, batch, data_shape, label_shape) |
| quantized_out = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape) |
| for i in range(len(ref_out)): |
        assert_almost_equal(ref_out[i].asnumpy(), quantized_out[i].asnumpy(), atol=1)
| |
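# Quantize a small two-convolution model end to end for each supported out_type and
# run a forward pass on the resulting quantized graph.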
| @with_seed() |
| def check_quantize_whole_model_with_forward(): |
| def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape): |
| mod = Module(symbol=qsym, label_names=None, context=mx.current_context()) |
| mod.bind(for_training=False, |
| data_shapes=[('data', data_shape)]) |
| mod.set_params(qarg_params, qaux_params) |
| data = [mx.random.uniform(-1.0, 1.0, shape=shape) for _, shape in mod.data_shapes] |
| batch = mx.io.DataBatch(data, []) |
| mod.forward(batch, is_train=False) |
| for output in mod.get_outputs(): |
| output.wait_to_read() |
| |
| def check_quantize_whole_model(out_type): |
| batch_size = 4 |
| data_shape = (batch_size, 4, 10, 10) |
| data = mx.sym.Variable('data') |
| conv0 = mx.sym.Convolution(data, kernel=(1, 1), num_filter=16, name='conv0') |
| sym = mx.sym.Convolution(conv0, kernel=(1, 1), num_filter=16, name='conv1') |
| sym_sg = sym.get_backend_symbol('MKLDNN') |
| mod = Module(symbol=sym, label_names=[]) |
| mod.bind(for_training=False, |
| data_shapes=[('data', data_shape)]) |
| |
| mod.init_params(mx.init.Normal(0.5)) |
| arg_params, aux_params = mod.get_params() |
| |
| excluded_sym_names = [] |
| |
| calib_data = mx.nd.random.uniform(shape=data_shape) |
| calib_data = NDArrayIter(data=calib_data) |
| calib_data = DummyIter(calib_data) |
| calib_layer = lambda name: name.endswith('_output') |
| qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg, |
| arg_params=arg_params, |
| aux_params=aux_params, |
| ctx=mx.current_context(), |
| excluded_sym_names=excluded_sym_names, |
| quantized_dtype=out_type, |
| calib_mode='naive', |
| calib_data=calib_data, |
| calib_layer=calib_layer, |
| num_calib_examples=5) |
| qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE') |
| check_qsym_forward(qsym, qarg_params, qaux_params, data_shape) |
| |
| for qdtype in ['uint8', 'int8', 'auto']: |
| check_quantize_whole_model(qdtype) |
| |
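# Check that the pattern fuses under the subgraph pass with the expected attributes,
# that fused and unfused graphs produce matching outputs, and (optionally) that the
# fused graph also quantizes correctly for every supported out_type.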
| @with_seed() |
| def check_fusion(sym, data_shape, attrs_op, name='conv', check_quantization=True): |
| op_name = config[name][OP_NAME] |
| sg_pass_name = config[name][SG_PASS_NAME] |
| |
| sym_sg = sym.get_backend_symbol(sg_pass_name) |
| assert ''.join(sym_sg.get_internals().list_outputs()).find(op_name) != -1 |
| for k, v in sym_sg.attr_dict().items(): |
| if k.find(op_name) != -1: |
| for attr_op in attrs_op: |
| assert v[attr_op] in ['true', 'True'] |
| |
| arg_shapes, _, aux_shapes = sym.infer_shape() |
| arg_array = [mx.nd.random.uniform(-1, 1, shape=shape) for shape in arg_shapes] |
| aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes] |
| exe = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null') |
| exe.forward() |
| os.environ['MXNET_SUBGRAPH_BACKEND'] = sg_pass_name |
| exe_sg = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null') |
| exe_sg.forward() |
| del os.environ['MXNET_SUBGRAPH_BACKEND'] |
| for i in range(len(exe.outputs)): |
| assert_almost_equal(exe.outputs[i].asnumpy(), exe_sg.outputs[i].asnumpy(), rtol=1e-3, atol=1e-3) |
| |
| # fp32 to int8 |
| out_type_list = ['uint8', 'int8', 'auto'] |
| |
| if check_quantization: |
| for out_type in out_type_list: |
| check_quantize(sym, data_shape, out_type, name=name) |
            # TODO(ciyong): quantized FC stores its params in int8, while Gluon loads the
            # corresponding variables from the symbol file as fp32, which results in a
            # dtype mismatch of the params. Skip quantized FC in the Gluon pass for now.
| if name != 'fc': |
| check_quantize(sym, data_shape, out_type, name=name, gluon_forward=True) |
| |
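# For patterns that must not (fully) fuse: apply the subgraph pass and assert that the
# listed attributes are present while the excluded attributes are absent on the fused nodes.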
def check_neg_fusion(syms, attrs_name=None, excluded_attrs=None,
                     data_shape=(4, 4, 10, 10), name='conv'):
| op_name = config[name][OP_NAME] |
| sg_pass_name = config[name][SG_PASS_NAME] |
| |
| for sym, attrs, excluded_attr in zip(syms, attrs_name, excluded_attrs): |
| sym_sg = sym.get_backend_symbol(sg_pass_name) |
        exe_sg = sym_sg.simple_bind(mx.cpu(), data=data_shape, grad_req='null')
| |
| attrs_dict = sym_sg.attr_dict() |
| for k, v in attrs_dict.items(): |
| if k.find(op_name) != -1: |
| for attr in attrs: |
| assert v[attr] == 'true' |
| for exc_attr in excluded_attr: |
| assert exc_attr not in v.keys() |
| |
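# Common head shared by all patterns: a data variable normalized by BatchNorm plus a
# shared weight variable.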
| def head_symbol(data_shape): |
| data = mx.symbol.Variable('data', shape=data_shape, dtype='float32') |
| weight = mx.symbol.Variable('weight', dtype='float32') |
| bn = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=0.9, name='bn') |
| return bn, weight |
| |
# single conv fusion case
| def single_conv(no_bias, data_shape): |
| conv_attr = [''] |
| data, weight = head_symbol(data_shape) |
| conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, |
| kernel=(3, 3), stride=(1, 1), no_bias=no_bias) |
| return conv, conv_attr |
| |
| # conv + bn fusion case |
| def conv_bn(no_bias, data_shape): |
| conv_bn_attr = ['with_bn'] |
| data, weight = head_symbol(data_shape) |
| conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, |
| kernel=(3, 3), stride=(1, 1), no_bias=no_bias) |
| bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") |
| return bn1, conv_bn_attr |
| |
| # conv + relu fusion case |
| def conv_relu(no_bias, data_shape): |
| conv_relu_attr = ['with_relu'] |
| data, weight = head_symbol(data_shape) |
| conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, |
| kernel=(3, 3), stride=(1, 1), no_bias=no_bias) |
| relu = mx.symbol.Activation(data=conv, name='relu', act_type="relu") |
| return relu, conv_relu_attr |
| |
| # conv + add fusion case |
| def conv_add(no_bias, data_shape): |
| conv_add_attr = ['with_sum'] |
| data, weight = head_symbol(data_shape) |
| conv1 = mx.symbol.Convolution(data=data, weight=weight, name='conv1', num_filter=64, |
| kernel=(3, 3), stride=(1, 1), no_bias=no_bias) |
| conv2 = mx.symbol.Convolution(data=data, name='conv2', num_filter=64, |
| kernel=(3, 3), stride=(1, 1)) |
| pool = mx.sym.Pooling(data=conv2, kernel=(1, 1), pool_type='avg', name='pool') |
    conv_sum = conv1 + pool
    return conv_sum, conv_add_attr
| |
| # conv + add fusion case 2 |
| def conv_add2(no_bias, data_shape): |
| conv_add_attr = ['with_sum'] |
| data, weight = head_symbol(data_shape) |
| conv1 = mx.symbol.Convolution(data=data, weight=weight, name='conv1', num_filter=64, |
| kernel=(3, 3), stride=(1, 1), no_bias=no_bias) |
| conv2 = mx.symbol.Convolution(data=data, name='conv2', num_filter=64, |
| kernel=(3, 3), stride=(1, 1)) |
| pool = mx.sym.Pooling(data=conv2, kernel=(1, 1), pool_type='avg', name='pool') |
    conv_sum = pool + conv1
    return conv_sum, conv_add_attr
| |
| # conv + bn + relu fusion case |
| def conv_bn_relu(no_bias, data_shape): |
| conv_bn_relu_attr = ['with_bn', 'with_relu'] |
| data, weight = head_symbol(data_shape) |
| conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, |
| kernel=(3, 3), stride=(1, 1), no_bias=no_bias) |
| bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") |
| relu = mx.symbol.Activation(data=bn1, name='relu', act_type="relu") |
| return relu, conv_bn_relu_attr |
| |
| # conv + bn + add + relu fusion case |
| def conv_bn_sum_relu(no_bias, data_shape): |
| conv_bn_add_relu_attr = ['with_sum', 'with_postsum_relu', 'with_bn'] |
| data, weight = head_symbol(data_shape) |
| conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, |
| kernel=(3, 3), stride=(1, 1), no_bias=no_bias) |
| bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") |
| conv1 = mx.symbol.Convolution(data=data, weight=weight, name='conv1', num_filter=64, |
| kernel=(3, 3), stride=(1, 1)) |
| sum1 = bn1 + conv1 |
| relu = mx.symbol.Activation(data=sum1, name='relu', act_type="relu") |
| return relu, conv_bn_add_relu_attr |
| |
| # single concat case |
| def single_concat(data_shape, input_num, dim): |
| data, weight = head_symbol(data_shape) |
| inputs = [] |
| for i in range(input_num): |
| inputs.append(data) |
| concat = mx.symbol.Concat(*inputs, name="concat", dim=dim) |
| return concat |
| |
| # concat scale alignment case |
| def concat_scale_align(data_shape): |
| data, weight = head_symbol(data_shape) |
| conv1 = mx.symbol.Convolution(data=data, weight=weight, name='conv1', num_filter=64, |
| kernel=(3, 3), stride=(1, 1), no_bias=True) |
| conv2 = mx.symbol.Convolution(data=data, weight=weight * 2, name='conv2', num_filter=64, |
| kernel=(3, 3), stride=(1, 1), no_bias=True) |
| conv3 = mx.symbol.Convolution(data=data, weight=weight * 3, name='conv3', num_filter=64, |
| kernel=(3, 3), stride=(1, 1), no_bias=True) |
| conv4 = mx.symbol.Convolution(data=data, weight=weight * 4, name='conv4', num_filter=64, |
| kernel=(3, 3), stride=(1, 1), no_bias=True) |
| concat = mx.symbol.Concat(*[conv1, conv2, conv3, conv4], name="concat", dim=1) |
| return concat |
| |
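# Common tail for the negative cases: each branch goes through its own FC, the results
# are concatenated and fed to SoftmaxOutput, so both branches stay live in the graph.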
| def tail_neg_symbol(sym1, sym2): |
| fc1 = mx.sym.FullyConnected(data=sym1, num_hidden=10, flatten=True, name='fc1') |
| fc2 = mx.sym.FullyConnected(data=sym2, num_hidden=10, flatten=True, name='fc2') |
| concat = mx.sym.Concat(*[fc1, fc2], name="concat") |
| sym = mx.sym.SoftmaxOutput(data=concat, name='softmax') |
| return sym |
| |
# cases where conv + bn can't be fused
# eg.1
# conv ---------> bn
| # | |
| # | |
| # -------------> [custom op] |
| def neg_conv_bn(data_shape): |
| syms = [] |
| attrs = [] |
| excluded_attrs = [] |
| data, weight = head_symbol(data_shape) |
| |
| # eg.1 ([custom op] = pool) |
| conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) |
| bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") |
| pool = mx.sym.Pooling(data=conv, kernel=(4, 4), pool_type='avg', name='pool') |
| sym = tail_neg_symbol(bn1, pool) |
| |
| syms.append(sym) |
| attrs.append([]) |
| excluded_attrs.append([]) |
| return syms, attrs, excluded_attrs |
| |
# cases where conv + relu can't be fused
# eg.1
# conv -----------> relu
#  |
#  |
#  ---------------> [custom op]
| def neg_conv_relu(data_shape): |
| syms = [] |
| attrs = [] |
| excluded_attrs = [] |
| data, weight = head_symbol(data_shape) |
| |
| # eg.1 ([custom op] = pool) |
| conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) |
| relu = mx.symbol.Activation(data=conv, name='relu', act_type="relu") |
| pool = mx.sym.Pooling(data=conv, kernel=(4, 4), pool_type='avg', name='pool') |
| sym = tail_neg_symbol(relu, pool) |
| |
| syms.append(sym) |
| attrs.append([]) |
| excluded_attrs.append([]) |
| return syms, attrs, excluded_attrs |
| |
# cases where conv + add can't be fused
# eg.1
#  ---------------> [custom op]
#  |
#  |
# conv -----------> add
#                     |
#                     |
# [added op] -------->
| def neg_conv_add(data_shape): |
| syms = [] |
| attrs = [] |
| excluded_attrs = [] |
| val = mx.symbol.Variable('addval') |
| data, weight = head_symbol(data_shape) |
| |
| # eg.1 ([custom op] = pool, [added op] = val) |
| conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) |
| sum1 = conv + val |
| pool = mx.sym.Pooling(data=conv, kernel=(4, 4), pool_type='avg', name='pool') |
| sym = tail_neg_symbol(sum1, pool) |
| |
| syms.append(sym) |
| attrs.append([]) |
    # use a list so check_neg_fusion iterates attribute names rather than characters
    excluded_attrs.append(['with_sum'])
| return syms, attrs, excluded_attrs |
| |
# cases where conv + bn + relu can't be fused
# eg.1 (branch taken from conv)
#   --------------> [custom op]
#   |
# conv -----------> bn -----------> relu
#
# eg.2 (branch taken from bn)
#                    -------------> [custom op]
#                    |
# conv -----------> bn -----------> relu
| def neg_conv_bn_relu(data_shape): |
| syms = [] |
| attrs = [] |
| excluded_attrs = [] |
| data, weight = head_symbol(data_shape) |
| |
| # eg.1 ([custom op] = pool11) |
| conv11 = mx.symbol.Convolution(data=data, weight=weight, name='conv11', num_filter=64, kernel=(3, 3), stride=(1, 1)) |
| bn11 = mx.symbol.BatchNorm(data=conv11, name="bn11") |
| relu11 = mx.symbol.Activation(data=bn11, name='relu11', act_type="relu") |
| pool11 = mx.sym.Pooling(data=conv11, kernel=(4, 4), pool_type='avg', name='pool11') |
| sym1 = tail_neg_symbol(relu11, pool11) |
| |
| syms.append(sym1) |
| attrs.append([]) |
| excluded_attrs.append([]) |
| |
| # eg.2 ([custom op] = pool) |
| conv21 = mx.symbol.Convolution(data=data, weight=weight, name='conv21', num_filter=64, kernel=(3, 3), stride=(1, 1)) |
| bn21 = mx.symbol.BatchNorm(data=conv21, name="bn21") |
| relu21 = mx.symbol.Activation(data=bn21, name='relu21', act_type="relu") |
| pool21 = mx.sym.Pooling(data=bn21, kernel=(4, 4), pool_type='avg', name='pool21') |
| sym2 = tail_neg_symbol(relu21, pool21) |
| |
| syms.append(sym2) |
| attrs.append(['with_bn']) |
| excluded_attrs.append(['with_relu']) |
| return syms, attrs, excluded_attrs |
| |
# cases where conv + bn + add + relu can't be fused
# eg.1 (branch taken from conv)
#   --------------> [custom op]
#   |
# conv -----------> bn -----------> add -----------> relu
#
# eg.2 (branch taken from bn)
#                    -------------> [custom op]
#                    |
# conv -----------> bn -----------> add -----------> relu
#
# eg.3 (branch taken from the add output)
#                                     -------------> [custom op]
#                                     |
# conv -----------> bn -----------> add -----------> relu
| def neg_conv_bn_add_relu(data_shape): |
| syms = [] |
| attrs = [] |
| excluded_attrs = [] |
| addVal = mx.symbol.Variable('addval') |
| data, weight = head_symbol(data_shape) |
| |
| # eg.1 |
| conv11 = mx.symbol.Convolution(data=data, weight=weight, name='conv11', num_filter=64, kernel=(3, 3), stride=(1, 1)) |
| bn11 = mx.symbol.BatchNorm(data=conv11, name="bn11") |
| sum11 = bn11 + addVal |
| relu11 = mx.symbol.Activation(data=sum11, name='relu11', act_type="relu") |
| pool11 = mx.sym.Pooling(data=conv11, kernel=(4, 4), pool_type='avg', name='pool11') |
| sym1 = tail_neg_symbol(relu11, pool11) |
| |
| syms.append(sym1) |
| attrs.append([]) |
| excluded_attrs.append(['with_sum', 'with_postsum_relu', 'with_bn']) |
| |
| # eg.2 |
| conv21 = mx.symbol.Convolution(data=data, weight=weight, name='conv21', num_filter=64, kernel=(3, 3), stride=(1, 1)) |
| bn21 = mx.symbol.BatchNorm(data=conv21, name="bn21") |
| sum21 = bn21 + addVal |
| relu21 = mx.symbol.Activation(data=sum21, name='relu21', act_type="relu") |
| pool21 = mx.sym.Pooling(data=bn21, kernel=(4, 4), pool_type='avg', name='pool21') |
| sym2 = tail_neg_symbol(relu21, pool21) |
| |
| syms.append(sym2) |
| attrs.append(['with_bn']) |
| excluded_attrs.append(['with_sum', 'with_postsum_relu']) |
| |
| # eg.3 |
| conv31 = mx.symbol.Convolution(data=data, weight=weight, name='conv31', num_filter=64, kernel=(3, 3), stride=(1, 1)) |
| bn31 = mx.symbol.BatchNorm(data=conv31, name="bn31") |
| sum31 = bn31 + addVal |
| relu31 = mx.symbol.Activation(data=sum31, name='relu31', act_type="relu") |
| pool31 = mx.sym.Pooling(data=sum31, kernel=(4, 4), pool_type='avg', name='pool31') |
| sym3 = tail_neg_symbol(relu31, pool31) |
| |
| syms.append(sym3) |
| attrs.append(['with_bn', 'with_sum']) |
| excluded_attrs.append(['with_postsum_relu']) |
| return syms, attrs, excluded_attrs |
| |
| def single_fc(no_bias, data_shape, flatten=True): |
| attr = [''] |
| data, weight = head_symbol(data_shape) |
| fc = mx.symbol.FullyConnected(name='fc', data=data, weight=weight, num_hidden=64, |
| no_bias=no_bias, flatten=flatten) |
| return fc, attr |
| |
| def fc_relu(no_bias, data_shape, flatten=True): |
| attr = ['with_relu'] |
| data, weight = head_symbol(data_shape) |
| fc = mx.symbol.FullyConnected(name='fc', data=data, weight=weight, num_hidden=64, |
| no_bias=no_bias, flatten=flatten) |
| relu = mx.symbol.Activation(data=fc, name='relu', act_type="relu") |
| return relu, attr |
| |
# cases where fc + relu can't be fused
| # eg.1 |
| # fc -----------> relu |
| # | |
| # | |
| # ---------------> [custom op] |
| def neg_fc_relu(no_bias, data_shape, flatten=True): |
| syms = [] |
| attrs = [] |
| excluded_attrs = [] |
| data, weight = head_symbol(data_shape) |
| |
    # eg.1 ([custom op] = sigmoid)
| fc = mx.symbol.FullyConnected(name='fc', data=data, weight=weight, num_hidden=64, |
| no_bias=no_bias, flatten=flatten) |
| relu = mx.symbol.Activation(data=fc, name='relu', act_type="relu") |
| sigmoid = mx.symbol.Activation(data=fc, name='sigmoid', act_type="sigmoid") |
| sym = tail_neg_symbol(relu, sigmoid) |
| |
| syms.append(sym) |
| attrs.append([]) |
| excluded_attrs.append([]) |
| return syms, attrs, excluded_attrs |
| |
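# Positive fusion/quantization tests, run over every shape in DATA_SHAPE.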
| @with_seed() |
| def test_pos_single_conv(): |
| for data_shape in DATA_SHAPE: |
| net, attrs = single_conv(False, data_shape) |
| check_fusion(net, data_shape, attrs) |
| net, attrs = single_conv(True, data_shape) |
| check_fusion(net, data_shape, attrs) |
| |
| @with_seed() |
| def test_pos_conv_relu(): |
| for data_shape in DATA_SHAPE: |
| net, attrs = conv_relu(False, data_shape) |
| check_fusion(net, data_shape, attrs) |
| net, attrs = conv_relu(True, data_shape) |
| check_fusion(net, data_shape, attrs) |
| |
| @with_seed() |
| def test_pos_conv_bn(): |
| for data_shape in DATA_SHAPE: |
| net, attrs = conv_bn(False, data_shape) |
| check_fusion(net, data_shape, attrs) |
| net, attrs = conv_bn(True, data_shape) |
| check_fusion(net, data_shape, attrs) |
| |
| @with_seed() |
| def test_pos_conv_add(): |
| for data_shape in DATA_SHAPE: |
| net, attrs = conv_add(False, data_shape) |
| check_fusion(net, data_shape, attrs) |
| net, attrs = conv_add(True, data_shape) |
| check_fusion(net, data_shape, attrs) |
| |
| @with_seed() |
| def test_pos_conv_add2(): |
| for data_shape in DATA_SHAPE: |
| net, attrs = conv_add2(False, data_shape) |
| check_fusion(net, data_shape, attrs) |
| net, attrs = conv_add2(True, data_shape) |
| check_fusion(net, data_shape, attrs) |
| |
| @with_seed() |
| def test_pos_conv_bn_relu(): |
| for data_shape in DATA_SHAPE: |
| net, attrs = conv_bn_relu(False, data_shape) |
| check_fusion(net, data_shape, attrs) |
| net, attrs = conv_bn_relu(True, data_shape) |
| check_fusion(net, data_shape, attrs) |
| |
| @with_seed() |
| def test_pos_conv_bn_sum_relu(): |
| for data_shape in DATA_SHAPE: |
| net, attrs = conv_bn_sum_relu(False, data_shape) |
| check_fusion(net, data_shape, attrs) |
| net, attrs = conv_bn_sum_relu(True, data_shape) |
| check_fusion(net, data_shape, attrs) |
| |
| @with_seed() |
| def test_pos_single_concat(): |
| for data_shape in DATA_SHAPE: |
| for out_type in ('uint8', 'int8', 'auto'): |
| net = single_concat(data_shape, 2, 1) |
| check_quantize(net, data_shape, out_type, name='conv', check_calibration=False) |
| check_quantize(net, data_shape, out_type, name='conv', check_calibration=False, gluon_forward=True) |
| net = single_concat(data_shape, 4, 2) |
| check_quantize(net, data_shape, out_type, name='conv', check_calibration=False) |
| check_quantize(net, data_shape, out_type, name='conv', check_calibration=False, gluon_forward=True) |
| net = single_concat(data_shape, 4, 3) |
| check_quantize(net, data_shape, out_type, name='conv', check_calibration=False) |
| check_quantize(net, data_shape, out_type, name='conv', check_calibration=False, gluon_forward=True) |
| |
| @with_seed() |
| def test_pos_concat_scale_align(): |
| for data_shape in DATA_SHAPE: |
| for out_type in ('uint8', 'int8', 'auto'): |
| net = concat_scale_align(data_shape) |
| check_quantize(net, data_shape, out_type, check_calibration=True, check_scale_align=True) |
| check_quantize(net, data_shape, out_type, check_calibration=True, check_scale_align=True, gluon_forward=True) |
| |
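# Negative conv fusion tests: these graphs must not fuse the full conv pattern, and the
# excluded attributes must be absent from the resulting subgraph nodes.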
| @with_seed() |
| def test_neg_conv_bn(): |
| for data_shape in DATA_SHAPE: |
| syms, attrs, excluded_attrs = neg_conv_bn(data_shape) |
| check_neg_fusion(syms, attrs, excluded_attrs, data_shape) |
| |
| @with_seed() |
| def test_neg_conv_relu(): |
| for data_shape in DATA_SHAPE: |
| syms, attrs, excluded_attrs = neg_conv_relu(data_shape) |
| check_neg_fusion(syms, attrs, excluded_attrs, data_shape) |
| |
| @with_seed() |
| def test_neg_conv_add(): |
| for data_shape in DATA_SHAPE: |
| syms, attrs, excluded_attrs = neg_conv_add(data_shape) |
| check_neg_fusion(syms, attrs, excluded_attrs, data_shape) |
| |
| @with_seed() |
| def test_neg_conv_bn_relu(): |
| for data_shape in DATA_SHAPE: |
| syms, attrs, excluded_attrs = neg_conv_bn_relu(data_shape) |
| check_neg_fusion(syms, attrs, excluded_attrs, data_shape) |
| |
| @with_seed() |
| def test_neg_conv_bn_add_relu(): |
| for data_shape in DATA_SHAPE: |
| syms, attrs, excluded_attrs = neg_conv_bn_add_relu(data_shape) |
| check_neg_fusion(syms, attrs, excluded_attrs, data_shape) |
| |
| @with_seed() |
| def test_single_fc(): |
| for dshape, no_bias, flatten in itertools.product(DATA_SHAPE, [True, False], [True, False]): |
| syms, attrs = single_fc(no_bias, dshape, flatten) |
        if flatten:
| check_fusion(syms, dshape, attrs, name='fc', check_quantization=True) |
| else: |
| check_fusion(syms, dshape, attrs, name='fc', check_quantization=False) |
| |
| |
| @with_seed() |
| def test_fc_relu(): |
| for dshape, no_bias, flatten in itertools.product(DATA_SHAPE, [True, False], [True, False]): |
| syms, attrs = fc_relu(no_bias, dshape, flatten) |
        if flatten:
| check_fusion(syms, dshape, attrs, name='fc', check_quantization=True) |
| else: |
| check_fusion(syms, dshape, attrs, name='fc', check_quantization=False) |
| |
| @with_seed() |
| def test_neg_fc_relu(): |
| for dshape, no_bias, flatten in itertools.product(DATA_SHAPE, [True, False], [True, False]): |
| syms, attrs, excluded_attrs = neg_fc_relu(no_bias, dshape, flatten) |
| check_neg_fusion(syms, attrs, excluded_attrs, dshape, name='fc') |
| |
| if __name__ == "__main__": |
| import nose |
| nose.runmodule() |