import ast
import json
import math
import sys

import mxnet as mx
import numpy as np

import utils
def calc_complexity(ishape, node):
    """Estimate the cost of a convolution node before and after factorization.

    Args:
        ishape: ``(C, Y, X)`` shape of the layer's input, without the batch
            axis (channels, height, width).
        node: Symbol-JSON node dict. ``node['param']['kernel']`` is the
            string form of a ``(y, x)`` tuple and
            ``node['param']['num_filter']`` is the number of output filters.

    Returns:
        A 2-tuple ``(per_rank_cost, full_cost)``:

        * ``per_rank_cost = x * (N + C) * X * Y`` -- ``get_ranksel``
          multiplies this by the number of retained ranks.
        * ``full_cost = x * y * N * C * X * Y`` -- cost of the original
          layer; ``get_ranksel`` sums it over layers to derive the budget.

    Raises:
        ValueError / SyntaxError: if the kernel string is not a plain Python
            literal.
    """
    # The kernel string comes from the model's symbol JSON.  literal_eval only
    # accepts Python literals, so a crafted model file cannot run arbitrary
    # code here the way a bare eval() would allow.
    y, x = map(int, ast.literal_eval(node['param']['kernel']))
    N = int(node['param']['num_filter'])
    C, Y, X = ishape
    # NOTE(review): the per-rank term uses the kernel width ``x`` for both
    # factors, and both terms use the *input* spatial size (stride/padding
    # are not considered).  Presumably square kernels are assumed -- confirm
    # before using this with non-square kernels.
    return x * (N + C) * X * Y, x * y * N * C * X * Y
def calc_eigenvalue(model, node):
    """Return the singular values of a convolution weight flattened to 2-D.

    Despite the name, the values returned are the *singular values* of the
    reshaped weight matrix, not eigenvalues.

    Args:
        model: Object whose ``arg_params`` mapping holds the layer weights;
            each value must expose ``asnumpy()``.
        node: Symbol-JSON node dict; the weight is looked up under
            ``node['name'] + '_weight'``.

    Returns:
        1-D numpy array with ``min(C*y, N*x)`` singular values (numpy returns
        them sorted in descending order).
    """
    weight = model.arg_params[node['name'] + '_weight'].asnumpy()
    N, C, y, x = weight.shape
    # (N, C, y, x) -> (C, y, N, x) -> (C*y, N*x)
    matrix = weight.transpose((1, 2, 0, 3)).reshape((C * y, -1))
    # Only the singular values are needed, so skip computing U and V.
    return np.linalg.svd(matrix, compute_uv=False)
def _select_ranks(rank_costs, max_ranks, cum_energy, budget):
    """Pick one rank per layer maximizing summed log-energy within ``budget``.

    This is a knapsack-style dynamic program keyed by accumulated cost.

    Args:
        rank_costs: per-layer cost of keeping one rank.
        max_ranks: per-layer upper bound on the rank.
        cum_energy: per-layer cumulative sums of singular values;
            ``cum_energy[i][d]`` is the energy kept with rank ``d + 1``.
        budget: maximum total cost allowed.

    Returns:
        List with the selected rank (>= 1) for every layer.

    Raises:
        ValueError: if no assignment of ranks fits within ``budget``.
    """
    n = len(rank_costs)
    best = {0: 0}                        # accumulated cost -> best objective
    choice = [{} for _ in range(n)]      # per layer: cost -> (d, previous cost)
    for i in range(n):
        # Kept from the original implementation; presumably lets output
        # printed by the caller show up before the long inner loops.
        sys.stdout.flush()
        limit = min(len(cum_energy[i]), max_ranks[i])
        # The gain depends only on (layer, rank): compute each log once
        # instead of once per DP state.
        gains = [math.log(cum_energy[i][d]) for d in range(limit)]
        step = rank_costs[i]
        nxt = {}
        for now_c, now_v in best.items():
            for d in range(limit):
                nxt_c = now_c + (d + 1) * step
                if nxt_c > budget:
                    # Cost is non-decreasing in d, so larger ranks cannot fit.
                    break
                nxt_v = now_v + gains[d]
                if nxt_c not in nxt or nxt_v > nxt[nxt_c]:
                    nxt[nxt_c] = nxt_v
                    choice[i][nxt_c] = (d, now_c)
        best = nxt

    if not best:
        raise ValueError(
            'no rank assignment fits within the complexity budget %r' % (budget,))

    # First state with the highest objective (same tie-break as a strict '>' scan).
    cost = max(best, key=best.get)
    ranks = [0] * n
    for i in reversed(range(n)):
        d, cost = choice[i][cost]
        ranks[i] = d + 1
    return ranks


def get_ranksel(model, ratio, data_shape=(1, 3, 224, 224)):
    """Select a rank for every convolution layer under a complexity budget.

    For each ``Convolution`` node of ``model.symbol`` the function collects
    its cost (``calc_complexity``) and the singular values of its weight
    (``calc_eigenvalue``).  The budget is the summed cost of the original
    layers divided by ``ratio``.  Ranks are then chosen to maximize the sum
    over layers of ``log(cumulative singular-value sum at the chosen rank)``
    while the summed per-rank cost stays within the budget.  (Normalizing
    each layer by its total energy would only shift the objective by a
    constant, so it is not done.)

    Args:
        model: Object exposing ``symbol`` (with ``tojson``,
            ``get_internals``) and ``arg_params``.
        ratio: Factor by which the total convolution cost is divided to
            obtain the budget.
        data_shape: ``(batch, C, Y, X)`` shape fed to shape inference for the
            ``data`` input.  Defaults to ``(1, 3, 224, 224)``.

    Returns:
        Dict mapping each convolution node name to its selected rank (>= 1).

    Raises:
        ValueError: if no rank assignment fits within the budget.
    """
    conf = json.loads(model.symbol.tojson())
    internals = model.symbol.get_internals()
    _, output_shapes, _ = internals.infer_shape(data=data_shape)
    out_shape_dic = dict(zip(internals.list_outputs(), output_shapes))
    nodes = utils.topsort(conf['nodes'])

    rank_costs = []     # cost of one retained rank, per conv layer
    max_ranks = []      # num_filter, used as an upper bound on the rank
    cum_energy = []     # cumulative sums of singular values, per conv layer
    conv_names = []
    full_cost = 0
    for node in nodes:
        if node['op'] != 'Convolution':
            continue
        input_nodes = [nodes[int(j[0])] for j in node['inputs']]
        # The layer's own parameters are named after the node, so the first
        # input whose name does not start with the node name is the data input.
        data = [input_node for input_node in input_nodes
                if not input_node['name'].startswith(node['name'])][0]
        if utils.is_input(data):
            ishape = tuple(data_shape[1:])
        else:
            ishape = out_shape_dic[data['name'] + '_output'][1:]
        per_rank, original = calc_complexity(ishape, node)
        rank_costs.append(per_rank)
        max_ranks.append(int(node['param']['num_filter']))
        cum_energy.append(np.cumsum(calc_eigenvalue(model, node)))
        conv_names.append(node['name'])
        full_cost += original

    budget = full_cost / ratio
    ranks = _select_ranks(rank_costs, max_ranks, cum_energy, budget)
    return dict(zip(conv_names, ranks))