blob: 6406b704ea5aca08d2d0943e6e25f9620b61b4eb [file] [log] [blame]
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converting tensorflow checkpoint file to key-val pkl file."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import os
import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import app
import inception_v4
import inception_v3
try:
import cPickle as pickle
except ModuleNotFoundError:
import pickle
FLAGS = None
def rename(name, suffix):
p = name.rfind('/')
if p == -1:
print('Bad name=%s' % name)
return name[0:p+1] + suffix
def convert(model, file_name):
if model == 'v3':
net, _ = inception_v3.create_net()
else:
net, _ = inception_v4.create_net()
params = {'SINGA_VERSION': 1101}
try:
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
for pname, pval in zip(net.param_names(), net.param_values()):
if 'weight' in pname:
val = reader.get_tensor(rename(pname, 'weights'))
if 'Conv' in pname:
val = val.transpose((3, 2, 0, 1))
val = val.reshape((val.shape[0], -1))
elif 'bias' in pname:
val = reader.get_tensor(rename(pname, 'biases'))
elif 'mean' in pname:
val = reader.get_tensor(rename(pname, 'moving_mean'))
elif 'var' in pname:
val = reader.get_tensor(rename(pname, 'moving_variance'))
elif 'beta' in pname:
val= reader.get_tensor(pname)
elif 'gamma' in pname:
val = np.ones(pval.shape)
else:
print('not matched param %s' % pname)
assert val.shape == pval.shape, ('the shapes not match ',
val.shape, pval.shape)
params[pname] = val.astype(np.float32)
print('converting:', pname, pval.shape)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
if 'weights' in key:
key = rename(key, 'weight')
elif 'biases' in key:
key = rename(key, 'bias')
elif 'moving_mean' in key:
key = rename(key, 'mean')
elif 'moving_variance' in key:
key = rename(key, 'var')
if key not in params:
print('key=%s not in the net' % key)
'''
for key in var_to_shape_map:
print("tensor_name: ", key, var_to_shape_map[key])
'''
with open(os.path.splitext(file_name)[0] + '.pickle', 'wb') as fd:
pickle.dump(params, fd)
except Exception as e: # pylint: disable=broad-except
print(str(e))
if "corrupted compressed block contents" in str(e):
print("It's likely that your checkpoint file has been compressed "
"with SNAPPY.")
if ("Data loss" in str(e) and
(any([e in file_name for e in [".index", ".meta", ".data"]]))):
proposed_file = ".".join(file_name.split(".")[0:-1])
v2_file_error_template = """
It's likely that this is a V2 checkpoint and you need to provide
the filename *prefix*. Try removing the '.' and extension. Try:
inspect checkpoint --file_name = {}"""
print(v2_file_error_template.format(proposed_file))
def main(unused_argv):
if not FLAGS.file_name:
print("Usage: convert.py --file_name=checkpoint_file_name ")
sys.exit(1)
else:
convert(FLAGS.model, FLAGS.file_name)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument("model", choices=['v3', 'v4'], help="inception version")
parser.add_argument("file_name", help="Checkpoint path")
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)