import mxnet as mx | |
import copy | |
import json | |
import ast | |
def load_model(args): | |
devs = mx.cpu() if args.gpus == None else [mx.gpu(int(i)) for i in args.gpus.split(',')] | |
return mx.model.FeedForward.load(args.model, args.load_epoch, ctx=devs) | |
def topsort(nodes): | |
n = len(nodes) | |
deg = [0]*n | |
g = [[] for _ in xrange(n)] | |
for i,node in enumerate(nodes): | |
if node.has_key('inputs'): | |
for j in node['inputs']: | |
deg[i] += 1 | |
g[j[0]].append(i) | |
from collections import deque | |
q = deque([i for i in xrange(n) if deg[i]==0]) | |
res = [] | |
for its in xrange(n): | |
i = q.popleft() | |
res.append(nodes[i]) | |
for j in g[i]: | |
deg[j] -= 1 | |
if deg[j] == 0: | |
q.append(j) | |
new_ids=dict([(node['name'],i) for i,node in enumerate(res)]) | |
for node in res: | |
if node.has_key('inputs'): | |
for j in node['inputs']: | |
j[0]=new_ids[nodes[j[0]]['name']] | |
return res | |
def is_input(node): | |
name = node['name'] | |
return len(node['inputs']) == 0 and ('weight' not in name) and ('bias' not in name) and ('label' not in name) | |
def sym_factory(node, data): | |
name = node['name'] | |
params = {} | |
if 'param' in node: | |
for k, v in node['param'].items(): | |
try: | |
params[k] = ast.literal_eval(v) | |
except ValueError, e: | |
params[k] = v | |
return getattr(mx.symbol, node['op'])(data=data, name=name, **params) | |
def replace_conv_layer(layer_name, old_model, sym_handle, arg_handle): | |
conf = json.loads(old_model.symbol.tojson()) | |
sym_dict = {} | |
nodes = conf['nodes'] | |
nodes = topsort(nodes) | |
res_sym = None | |
new_model = old_model | |
for i,node in enumerate(nodes): | |
sym = None | |
if is_input(node): | |
sym = mx.symbol.Variable(name='data') | |
elif node['op'] != 'null': | |
input_nodes = [nodes[int(j[0])] for j in node['inputs']] | |
datas = [input_node['name'] for input_node in input_nodes\ | |
if not input_node['name'].startswith(node['name'])] | |
try: | |
data=sym_dict[datas[0]] | |
except Exception, e: | |
print 'can not find symbol %s'%(datas[0]) | |
raise e | |
if node['name'] == layer_name: | |
sym = sym_handle(data, node) | |
else: | |
sym = sym_factory(node, data) | |
if sym: | |
sym_dict[node['name']] = sym | |
res_sym = sym | |
arg_params = copy.deepcopy(old_model.arg_params) | |
if layer_name: | |
arg_shapes, _, _ = res_sym.infer_shape(data=(1,3,224,224)) | |
arg_names = res_sym.list_arguments() | |
arg_shape_dic = dict(zip(arg_names, arg_shapes)) | |
try: | |
arg_handle(arg_shape_dic, arg_params) | |
except Exception, e: | |
raise Exception('Exception in arg_handle') | |
new_model = mx.model.FeedForward( | |
symbol=res_sym, | |
ctx=old_model.ctx, | |
num_epoch=1, | |
epoch_size=old_model.epoch_size, | |
optimizer='sgd', | |
initializer=old_model.initializer, | |
numpy_batch_size=old_model.numpy_batch_size, | |
arg_params=arg_params, | |
aux_params=old_model.aux_params, | |
allow_extra_params=True, | |
begin_epoch=old_model.begin_epoch) | |
return new_model |