| import numpy as np | |
| from scipy import linalg as LA | |
| import mxnet as mx | |
| import argparse | |
| import utils | |
| import pdb | |
def fc_decomposition(model, args):
    """Replace one fully-connected layer by two low-rank FC layers via truncated SVD.

    The weight ``W`` of layer ``args.layer`` (flattened to 2-D) is factorized as
    ``W ~= P @ Q`` with ``P = U[:, :K]`` and ``Q = diag(S[:K]) @ Vt[:K, :]``.
    The layer is then replaced by:

    * ``<layer>_red``: FullyConnected, no bias, weight ``Q`` (``K`` outputs);
    * ``<layer>_rec``: FullyConnected, weight ``P``, carrying the original bias.

    Args:
        model: model object exposing ``arg_params`` (name -> NDArray-like with
            ``asnumpy()``); presumably what ``utils.load_model`` returns — confirm.
        args: namespace with ``layer`` (parameter-name prefix of the layer) and
            ``K`` (target rank). ``K=None`` keeps the full rank, and a ``K``
            larger than the rank of ``W`` is clamped by slicing.

    Returns:
        Whatever ``utils.replace_conv_layer`` returns for the rewritten model.

    Raises:
        ValueError: if ``args.K`` is given and is smaller than 1.
    """
    W = model.arg_params[args.layer + '_weight'].asnumpy()
    b = model.arg_params[args.layer + '_bias'].asnumpy()
    W = W.reshape((W.shape[0], -1))
    b = b.reshape((b.shape[0], -1))

    rk = args.K
    # K < 1 would give a zero-width layer (K == 0) or silently drop trailing
    # singular values through negative slicing (K < 0); reject it up front.
    if rk is not None and rk < 1:
        raise ValueError('--K must be a positive integer, got %r' % (rk,))

    # Economy SVD: W = u @ diag(s) @ v, with s sorted in descending order.
    u, s, v = LA.svd(W, full_matrices=False)
    P = u[:, :rk]
    # Row-scaling v by s is the same as diag(s)[:rk, :rk] @ v[:rk, :], without
    # materialising the dense diagonal matrix.
    Q = s[:rk, None] * v[:rk, :]

    name1 = args.layer + '_red'
    name2 = args.layer + '_rec'

    def sym_handle(data, node):
        # Build the replacement symbol: data -> Q (no bias) -> P (+ original bias).
        # ``node`` is unused here; it is part of the callback signature that
        # utils.replace_conv_layer presumably expects — confirm.
        W1, W2 = Q, P
        sym1 = mx.symbol.FullyConnected(data=data, num_hidden=W1.shape[0], no_bias=True, name=name1)
        sym2 = mx.symbol.FullyConnected(data=sym1, num_hidden=W2.shape[0], no_bias=False, name=name2)
        return sym2

    def arg_handle(arg_shape_dic, arg_params):
        # Fill ``arg_params`` in place with the factor weights and original bias.
        # ``arg_shape_dic`` presumably maps argument names to inferred shapes — confirm.
        W1, W2 = Q, P
        W1 = W1.reshape(arg_shape_dic[name1 + '_weight'])
        weight1 = mx.ndarray.array(W1)
        W2 = W2.reshape(arg_shape_dic[name2 + '_weight'])
        b2 = b.reshape(arg_shape_dic[name2 + '_bias'])
        weight2 = mx.ndarray.array(W2)
        bias2 = mx.ndarray.array(b2)
        arg_params[name1 + '_weight'] = weight1
        arg_params[name2 + '_weight'] = weight2
        arg_params[name2 + '_bias'] = bias2

    new_model = utils.replace_conv_layer(args.layer, model, sym_handle, arg_handle)
    return new_model
def main():
    """Load the model, decompose the selected FC layer, and save the result.

    Reads the module-level ``args`` namespace, which is assigned only when this
    file is executed as a script (see the ``__main__`` guard).
    """
    opts = args  # module-level namespace produced by argparse
    decomposed = fc_decomposition(utils.load_model(opts), opts)
    decomposed.save(opts.save_model)
| if __name__ == '__main__': | |
| parser=argparse.ArgumentParser() | |
| parser.add_argument('-m', '--model', help='the model to speed up') | |
| parser.add_argument('-g', '--gpus', default='0', help='the gpus to be used in ctx') | |
| parser.add_argument('--load-epoch',type=int,default=1) | |
| parser.add_argument('--layer') | |
| parser.add_argument('--K', type=int) | |
| parser.add_argument('--save-model') | |
| args = parser.parse_args() | |
| main() |