blob: 8786647678d97d2cccf748b6389fc0b0195eef76 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''Extract the VGG net parameters from a PyTorch checkpoint and store them
as a Python dict using (c)Pickle. Requires PyTorch to be installed.
'''
import torch.utils.model_zoo as model_zoo
import numpy as np
from argparse import ArgumentParser
import model
try:
import cPickle as pickle
except ModuleNotFoundError:
import pickle
# Download locations of the pretrained VGG checkpoints on the PyTorch model
# zoo, keyed by architecture name. The ``_bn`` entries are the variants picked
# when the converter is run with ``--batchnorm``.
model_urls = dict(
    vgg11='https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    vgg13='https://download.pytorch.org/models/vgg13-c768596a.pth',
    vgg16='https://download.pytorch.org/models/vgg16-397923af.pth',
    vgg19='https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    vgg11_bn='https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    vgg13_bn='https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    vgg16_bn='https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    vgg19_bn='https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
)
def rename(pname):
    """Map a SINGA parameter name to the matching PyTorch state-dict key.

    The layer name is the text between the first and the last ``/`` of
    *pname*. The leaf (the text after the last ``/``) determines the key
    suffix: a leaf containing ``gamma``, ``beta``, ``mean`` or ``var`` becomes
    ``weight``, ``bias``, ``running_mean`` or ``running_var`` respectively
    (checked in that order, first match wins). Any other leaf is kept as is.

    Example: ``'net/features.1/gamma'`` -> ``'features.1.weight'``.

    Args:
        pname: parameter name containing at least two ``/`` separators.

    Returns:
        The key ``'<layer>.<suffix>'``.

    Raises:
        ValueError: if *pname* has fewer than two ``/`` separators.
    """
    first = pname.find('/')
    last = pname.rfind('/')
    # Two distinct separators are required. With a single one the layer name
    # would be empty, and the resulting key could never match the checkpoint.
    # An explicit raise is used (not assert) so the check survives `python -O`.
    if first == -1 or first == last:
        raise ValueError('param name = %s is not correct' % pname)
    leaf = pname[last + 1:]
    # Only the leaf is inspected, so a keyword appearing in the prefix or
    # layer part of the name cannot select the wrong suffix.
    for keyword, torch_suffix in (('gamma', 'weight'),
                                  ('beta', 'bias'),
                                  ('mean', 'running_mean'),
                                  ('var', 'running_var')):
        if keyword in leaf:
            suffix = torch_suffix
            break
    else:
        suffix = leaf
    return pname[first + 1:last] + '.' + suffix
if __name__ == '__main__':
    parser = ArgumentParser(description='Convert params from torch to python '
                            'dict.')
    parser.add_argument("depth", type=int, choices=[11, 13, 16, 19])
    parser.add_argument("outfile")
    parser.add_argument("--batchnorm", action='store_true',
                        help='use batchnorm or not')
    args = parser.parse_args()

    # NOTE(review): 1000 is presumably the number of output classes of the
    # pretrained checkpoints -- confirm against model.create_net.
    net = model.create_net(args.depth, 1000, args.batchnorm)

    # Select the checkpoint key, e.g. 'vgg16' or 'vgg16_bn', and load it.
    url = 'vgg%d' % args.depth
    if args.batchnorm:
        url += '_bn'
    torch_dict = model_zoo.load_url(model_urls[url])

    # NOTE(review): 'SINGA_VERSION' is stored next to the params; presumably
    # read by the loader of the pickled file -- confirm.
    params = {'SINGA_VERSION': 1101}
    for pname, pval in zip(net.param_names(), net.param_values()):
        torch_name = rename(pname)
        if torch_name not in torch_dict:
            # Warn and skip: there is no converted value to shape-check or
            # store for this param.
            print('param=%s is missing in the ckpt file' % pname)
            continue
        ary = np.array(torch_dict[torch_name].numpy(), dtype=np.float32)
        if ary.ndim == 4:
            # Flatten 4-D tensors (presumably conv kernels) to 2-D, keeping
            # the first axis.
            ary = np.reshape(ary, (ary.shape[0], -1))
        else:
            # Everything else is transposed (a no-op for 1-D arrays). This
            # assumes the net's 2-D weights use the transpose of the checkpoint
            # layout; the shape check below guards that assumption.
            ary = np.transpose(ary)
        # Explicit check instead of assert so it still runs under `python -O`.
        if pval.shape != ary.shape:
            raise ValueError('shape mismatch for %s: net expects %s, '
                             'checkpoint gives %s'
                             % (pname, pval.shape, ary.shape))
        params[pname] = ary

    with open(args.outfile, 'wb') as fd:
        pickle.dump(params, fd)