# blob: ee3eca91f9344aad1c55f5e40c4feb0ee28c5faf [file] [log] [blame]
import numpy as np
import mxnet as mx
import json
import utils
import math
import sys
def calc_complexity(ishape, node):
    """Return the two cost terms of a Convolution node for a given input shape.

    Args:
        ishape: ``(C, Y, X)`` triple describing the node's input with the
            batch axis already removed (presumably channels, height, width).
        node: Symbol-JSON node dict. ``node['param']['kernel']`` is a string
            holding a ``(y, x)`` pair such as ``"(3, 3)"`` and
            ``node['param']['num_filter']`` is the filter count ``N``.

    Returns:
        The tuple ``(x*(N+C)*X*Y, x*y*N*C*X*Y)``.  ``get_ranksel`` multiplies
        the first term by a candidate rank and sums the second term over all
        convolutions to obtain the reference cost of the unmodified network.

    Raises:
        ValueError, SyntaxError: if the kernel string is not a plain Python
            literal.
    """
    # The kernel attribute comes from the model's symbol JSON.  Parse it as a
    # literal instead of eval()-ing it, so a crafted model file cannot run
    # arbitrary code.  Imported locally to keep this function self-contained.
    import ast

    params = node['param']
    kernel_h, kernel_w = (int(v) for v in ast.literal_eval(params['kernel']))
    num_filter = int(params['num_filter'])
    channels, height, width = ishape

    area = width * height
    per_rank_cost = kernel_w * (num_filter + channels) * area
    full_cost = kernel_w * kernel_h * num_filter * channels * area
    return per_rank_cost, full_cost
def calc_eigenvalue(model, node):
    """Return the singular values of a convolution's unfolded weight tensor.

    The weight stored under ``<node name>_weight`` in ``model.arg_params`` is a
    4-D array of shape ``(N, C, y, x)`` (presumably filters, input channels,
    kernel height, kernel width).  It is rearranged to ``(C, y, N, x)`` and
    flattened into a ``(C*y, N*x)`` matrix whose singular values are returned.

    Args:
        model: object exposing ``arg_params``, a mapping from parameter name
            to an array-like with an ``asnumpy()`` method.
        node: symbol-JSON node dict; only ``node['name']`` is read.

    Returns:
        1-D numpy array of singular values, as returned by
        ``numpy.linalg.svd`` (sorted in descending order).
    """
    weight = model.arg_params[node['name'] + '_weight'].asnumpy()
    num_filter, channels, kernel_h, kernel_w = weight.shape
    unfolded = weight.transpose((1, 2, 0, 3)).reshape((channels * kernel_h, -1))
    # Only the singular values are needed, so skip building the U / Vh
    # factors that the original computed and immediately discarded.
    return np.linalg.svd(unfolded, compute_uv=False)
def get_ranksel(model, ratio, data_shape=(1, 3, 224, 224)):
    """Pick a rank for every Convolution node under a total-cost budget.

    For each convolution the per-rank cost and the full cost are obtained from
    ``calc_complexity`` and the singular values from ``calc_eigenvalue``.  The
    budget is the summed full cost of all convolutions divided by ``ratio``.
    A dynamic program over (layer, accumulated cost) then selects one rank per
    layer so that the accumulated cost ``sum(rank_i * per_rank_cost_i)`` stays
    within the budget while ``sum(log(sum of the first rank_i singular
    values))`` is maximised.

    Args:
        model: object exposing ``symbol`` (with ``tojson``, ``get_internals``)
            and ``arg_params``, as used by ``calc_eigenvalue``.
        ratio: divisor applied to the summed full cost to obtain the budget.
        data_shape: shape passed as ``data`` to ``infer_shape``.  Everything
            after its first axis is used as the input shape of convolutions
            fed directly by an input node.  Defaults to the previously
            hard-coded ``(1, 3, 224, 224)``.

    Returns:
        dict mapping each convolution node name to its selected rank
        (a 1-based count).  Empty when the graph has no Convolution node.

    Raises:
        KeyError: if there is at least one convolution and no rank assignment
            fits the budget (``ratio`` too large).
    """
    conf = json.loads(model.symbol.tojson())
    internals = model.symbol.get_internals()
    _, output_shapes, _ = internals.infer_shape(data=data_shape)
    out_shape_dic = dict(zip(internals.list_outputs(), output_shapes))
    nodes = utils.topsort(conf['nodes'])

    C = []           # (per-rank cost, full cost) for each convolution
    D = []           # num_filter for each convolution
    S = []           # singular values, turned into running sums below
    conv_names = []
    EC = 0           # summed full cost; becomes the budget after division
    for node in nodes:
        if node['op'] != 'Convolution':
            continue
        input_nodes = [nodes[int(j[0])] for j in node['inputs']]
        # The data input is the first input whose name is not prefixed by the
        # convolution's own name (this skips e.g. ``<name>_weight``).
        data = [input_node for input_node in input_nodes
                if not input_node['name'].startswith(node['name'])][0]
        if utils.is_input(data):
            ishape = tuple(data_shape[1:])
        else:
            ishape = out_shape_dic[data['name'] + '_output'][1:]
        C.append(calc_complexity(ishape, node))
        D.append(int(node['param']['num_filter']))
        S.append(calc_eigenvalue(model, node))
        conv_names.append(node['name'])
        EC += C[-1][1]

    # In-place prefix sums: S[i][d] becomes the sum of the first d+1 values.
    for spectrum in S:
        for i in range(1, len(spectrum)):
            spectrum[i] += spectrum[i - 1]

    n = len(C)
    EC /= ratio

    # best maps accumulated cost -> best objective value after the layers
    # processed so far; dpc[i] maps accumulated cost after layer i to
    # (chosen rank index, accumulated cost before layer i) for backtracking.
    best = {0: 0}
    dpc = [{} for _ in range(n)]
    for i in range(n):
        following = {}
        sys.stdout.flush()  # kept from the original: flush once per layer
        unit_cost = C[i][0]
        rank_range = range(min(len(S[i]), D[i]))
        for now_c, now_v in best.items():
            for d in rank_range:
                nxt_c = now_c + (d + 1) * unit_cost
                if nxt_c > EC:
                    # Cost only grows with d (unit_cost is a product of
                    # kernel/filter/feature-map sizes), so no larger rank
                    # can fit either.
                    break
                nxt_v = now_v + math.log(S[i][d])
                if nxt_c not in following or nxt_v > following[nxt_c]:
                    following[nxt_c] = nxt_v
                    dpc[i][nxt_c] = (d, now_c)
        best = following

    if n and not best:
        # Same exception type the original raised (a bare ``KeyError: 0``
        # during backtracking), now with an explanation.
        raise KeyError('no rank assignment fits the complexity budget '
                       '(ratio=%r)' % (ratio,))

    maxv = float('-inf')
    target_c = 0
    for c, v in best.items():
        assert c <= EC, 'False'
        if v > maxv:
            maxv = v
            target_c = c

    # Walk the back-pointers from the last layer to the first.
    res = [0] * n
    nowc = target_c
    for i in range(n - 1, -1, -1):
        rank_index, prev_c = dpc[i][nowc]
        res[i] = rank_index + 1
        nowc = prev_c
    return dict(zip(conv_names, res))