blob: 93b2e12d1246517c85529f1baf6bc4daf9c2081d [file] [log] [blame]
from __future__ import print_function
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import mxnet as mx
import copy
import json
import ast
from six.moves import xrange
def load_model(args):
    """Load a saved FeedForward checkpoint onto the devices named in ``args``.

    Args:
        args: Object exposing ``gpus`` (``None`` to run on CPU, otherwise a
            comma-separated string of GPU ids such as ``"0,1"``), ``model``
            and ``load_epoch``. The last two are forwarded unchanged to
            ``mx.model.FeedForward.load``.

    Returns:
        The value returned by ``mx.model.FeedForward.load``.

    Raises:
        ValueError: If an entry of ``args.gpus`` is not an integer literal.
    """
    # Identity comparison: ``== None`` would dispatch to a custom ``__eq__``.
    if args.gpus is None:
        devs = mx.cpu()
    else:
        devs = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    return mx.model.FeedForward.load(args.model, args.load_epoch, ctx=devs)
def topsort(nodes):
    """Return the graph nodes in topological order (inputs before consumers).

    Each node is a dict that may carry an ``'inputs'`` list; every entry of
    that list is itself a mutable list whose first element is the index (into
    ``nodes``) of the producing node.

    The node dicts are reused, not copied: after ordering, the first element
    of every input entry is rewritten in place so that it indexes into the
    returned list instead of into ``nodes``.

    Args:
        nodes: List of node dicts as described above.

    Returns:
        A new list holding the same node dicts in dependency order.

    Raises:
        IndexError: If the graph contains a cycle (no node is ready while
            some are still unplaced), or an input index is out of range.
    """
    from collections import deque

    n = len(nodes)
    indegree = [0] * n
    consumers = [[] for _ in range(n)]
    for i, node in enumerate(nodes):
        for edge in node.get('inputs', ()):
            indegree[i] += 1
            consumers[edge[0]].append(i)

    # Kahn's algorithm: repeatedly emit a node whose inputs are all emitted.
    ready = deque(i for i in range(n) if indegree[i] == 0)
    order = []
    for _ in range(n):
        if not ready:
            raise IndexError(
                'topsort: graph contains a cycle; no valid ordering exists')
        i = ready.popleft()
        order.append(i)
        for j in consumers[i]:
            indegree[j] -= 1
            if indegree[j] == 0:
                ready.append(j)

    # Map old position -> new position directly. Going through node names
    # would mis-resolve inputs whenever two nodes share a name.
    new_index = [0] * n
    for new_pos, old_pos in enumerate(order):
        new_index[old_pos] = new_pos

    res = [nodes[i] for i in order]
    for node in res:
        for edge in node.get('inputs', ()):
            edge[0] = new_index[edge[0]]
    return res
def is_input(node):
    """Tell whether a graph node looks like a data input.

    A node qualifies when it has no inputs and its name contains none of the
    substrings ``'weight'``, ``'bias'`` or ``'label'``. This is a name-based
    heuristic only.

    Args:
        node: Node dict with a ``'name'`` string and, optionally, an
            ``'inputs'`` list. A missing ``'inputs'`` key is treated the same
            as an empty list, matching how ``topsort`` reads nodes.

    Returns:
        bool: ``True`` if the node is considered a data input.
    """
    if node.get('inputs'):
        return False
    name = node['name']
    return not any(tag in name for tag in ('weight', 'bias', 'label'))
def sym_factory(node, data):
    """Instantiate the ``mx.symbol`` operator described by a graph node.

    Args:
        node: Node dict with ``'name'``, ``'op'`` (the attribute looked up on
            ``mx.symbol``) and, optionally, ``'param'``: a mapping of
            operator keyword names to their values.
        data: Value passed to the operator as its ``data`` keyword.

    Returns:
        The result of calling the operator with ``data``, ``name`` and the
        decoded parameters.
    """
    name = node['name']
    params = {}
    if 'param' in node:
        for key, raw in node['param'].items():
            try:
                params[key] = ast.literal_eval(raw)
            except (ValueError, SyntaxError, TypeError):
                # Not a Python literal (e.g. "max", "", or free-form text):
                # hand the raw value to the operator unchanged.
                params[key] = raw
    return getattr(mx.symbol, node['op'])(data=data, name=name, **params)
def replace_conv_layer(layer_name, old_model, sym_handle, arg_handle,
                       data_shape=(1, 3, 224, 224)):
    """Rebuild a model's symbol graph, substituting one named layer.

    The symbol of ``old_model`` is serialised to JSON, its nodes are
    topologically sorted, and the graph is re-created node by node as a
    single chain: each operator is fed the first of its inputs whose name
    does not start with the operator's own name. Nodes accepted by
    ``is_input`` become a variable named ``'data'``.

    Args:
        layer_name: Name of the node to replace. When falsy, no node matches
            by name in practice and ``arg_handle`` is not called.
        old_model: Model providing ``symbol``, ``arg_params``, ``aux_params``
            and the settings copied into the new ``mx.model.FeedForward``.
        sym_handle: Callable ``sym_handle(data, node)`` returning the
            replacement symbol for the node named ``layer_name``.
        arg_handle: Callable ``arg_handle(arg_shape_dic, arg_params)``. It
            receives the inferred argument shapes keyed by argument name and
            a deep copy of the old argument parameters; its return value is
            ignored, so it presumably edits ``arg_params`` in place.
        data_shape: Shape bound to ``data`` for shape inference. Defaults to
            the previously hard-coded ``(1, 3, 224, 224)``.

    Returns:
        A new ``mx.model.FeedForward`` built on the rewritten symbol.

    Raises:
        IndexError: If an operator node has no usable data input.
        KeyError: If an operator's data input has no symbol built yet.
        Exception: If ``arg_handle`` raises; the message carries the cause.
    """
    conf = json.loads(old_model.symbol.tojson())
    nodes = topsort(conf['nodes'])
    sym_dict = {}
    res_sym = None

    for node in nodes:
        sym = None
        if is_input(node):
            sym = mx.symbol.Variable(name='data')
        elif node['op'] != 'null':
            input_nodes = [nodes[int(j[0])] for j in node['inputs']]
            # Skip the operator's own parameters (e.g. "<name>_weight").
            data_names = [inp['name'] for inp in input_nodes
                          if not inp['name'].startswith(node['name'])]
            if not data_names:
                raise IndexError(
                    'node %s has no data input' % (node['name'],))
            try:
                data = sym_dict[data_names[0]]
            except KeyError:
                print('can not find symbol %s'%(data_names[0]))
                raise
            if node['name'] == layer_name:
                sym = sym_handle(data, node)
            else:
                sym = sym_factory(node, data)
        # Compare against None explicitly rather than relying on the truth
        # value of a symbol object.
        if sym is not None:
            sym_dict[node['name']] = sym
            res_sym = sym

    arg_params = copy.deepcopy(old_model.arg_params)
    if layer_name:
        arg_shapes, _, _ = res_sym.infer_shape(data=data_shape)
        arg_names = res_sym.list_arguments()
        arg_shape_dic = dict(zip(arg_names, arg_shapes))
        try:
            arg_handle(arg_shape_dic, arg_params)
        except Exception as e:
            # Keep the generic exception type but surface the real cause.
            raise Exception('Exception in arg_handle: %r' % (e,))

    return mx.model.FeedForward(
        symbol=res_sym,
        ctx=old_model.ctx,
        num_epoch=1,
        epoch_size=old_model.epoch_size,
        optimizer='sgd',
        initializer=old_model.initializer,
        numpy_batch_size=old_model.numpy_batch_size,
        arg_params=arg_params,
        aux_params=old_model.aux_params,
        allow_extra_params=True,
        begin_epoch=old_model.begin_epoch)