blob: 973debd94b8bca94fec88b3dc882ae1ca0f5ec82 [file] [log] [blame]
"""Converting tensorflow checkpoint file to key-val pkl file."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import os
import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import app
import inception_v4
import inception_v3
try:
import cPickle as pickle
except ImportError:
import pickle
FLAGS = None
def rename(name, suffix):
    """Replace the last "/"-separated component of *name* with *suffix*.

    Everything up to and including the final "/" is kept and *suffix* is
    appended to it. When *name* contains no "/", a "Bad name" notice is
    printed and the result is *suffix* on its own.
    """
    # rpartition yields ("", "", name) when there is no "/", so prefix and
    # slash are both empty in that case.
    prefix, slash, _ = name.rpartition("/")
    if not slash:
        print("Bad name=%s" % name)
    return prefix + slash + suffix
def convert(model, file_name):
    """Convert a TensorFlow checkpoint into a pickled name -> array dict.

    The network is built with ``inception_v3.create_net()`` when *model* is
    "v3" and with ``inception_v4.create_net()`` otherwise. For every parameter
    of that network the matching tensor is read from the checkpoint, checked
    against the parameter's shape, cast to float32 and stored under the
    parameter's name. The resulting dict (which also carries a
    ``"SINGA_VERSION"`` entry) is pickled next to the checkpoint, at
    ``os.path.splitext(file_name)[0] + ".pickle"``.

    Args:
        model: "v3" selects Inception v3; any other value selects v4.
        file_name: checkpoint path handed to the checkpoint reader.

    Returns:
        None. Exceptions raised while reading, converting or writing are
        caught, printed (with extra hints for two known failure messages) and
        not re-raised; in that case no pickle file is written unless the
        failure happened during the dump itself.
    """
    if model == "v3":
        net, _ = inception_v3.create_net()
    else:
        net, _ = inception_v4.create_net()
    params = {"SINGA_VERSION": 1101}
    try:
        reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        for pname, pval in zip(net.param_names(), net.param_values()):
            # Map the network's parameter name onto the checkpoint's tensor
            # name by swapping the last path component.
            if "weight" in pname:
                val = reader.get_tensor(rename(pname, "weights"))
                if "Conv" in pname:
                    # Reorder the 4-D kernel axes and flatten to 2-D, keeping
                    # the new leading axis. Presumably this maps TF's
                    # (H, W, in, out) layout to (out, in*H*W) -- TODO confirm.
                    val = val.transpose((3, 2, 0, 1))
                    val = val.reshape((val.shape[0], -1))
            elif "bias" in pname:
                val = reader.get_tensor(rename(pname, "biases"))
            elif "mean" in pname:
                val = reader.get_tensor(rename(pname, "moving_mean"))
            elif "var" in pname:
                val = reader.get_tensor(rename(pname, "moving_variance"))
            elif "beta" in pname:
                val = reader.get_tensor(pname)
            elif "gamma" in pname:
                # gamma is not read from the checkpoint; it is filled with ones.
                val = np.ones(pval.shape)
            else:
                print("not matched param %s" % pname)
                # Skip it: there is no value for this parameter, and falling
                # through would reuse the previous parameter's array (or hit
                # an unbound name on the first iteration).
                continue
            # Explicit check rather than `assert`, so it still runs under -O.
            if val.shape != pval.shape:
                raise ValueError(
                    "the shapes not match for %s: %s vs %s"
                    % (pname, val.shape, pval.shape)
                )
            params[pname] = val.astype(np.float32)
            print("converting:", pname, pval.shape)

        # Report checkpoint tensors that have no counterpart in the network,
        # applying the reverse of the name mapping used above.
        var_to_shape_map = reader.get_variable_to_shape_map()
        for key in var_to_shape_map:
            if "weights" in key:
                net_key = rename(key, "weight")
            elif "biases" in key:
                net_key = rename(key, "bias")
            elif "moving_mean" in key:
                net_key = rename(key, "mean")
            elif "moving_variance" in key:
                net_key = rename(key, "var")
            else:
                net_key = key
            if net_key not in params:
                print("key=%s not in the net" % net_key)

        with open(os.path.splitext(file_name)[0] + ".pickle", "wb") as fd:
            pickle.dump(params, fd)
    except Exception as e:  # pylint: disable=broad-except
        # Best-effort tool: report the failure and add a hint for the two
        # error messages recognised below.
        message = str(e)
        print(message)
        if "corrupted compressed block contents" in message:
            print(
                "It's likely that your checkpoint file has been compressed "
                "with SNAPPY."
            )
        if "Data loss" in message and any(
            ext in file_name for ext in (".index", ".meta", ".data")
        ):
            # Suggest the path with its last "."-separated component removed.
            proposed_file = ".".join(file_name.split(".")[0:-1])
            v2_file_error_template = """
It's likely that this is a V2 checkpoint and you need to provide
the filename *prefix*. Try removing the '.' and extension. Try:
inspect checkpoint --file_name = {}"""
            print(v2_file_error_template.format(proposed_file))
def main(unused_argv):
    """Run the conversion described by the module-level ``FLAGS``.

    Args:
        unused_argv: argument list supplied by the caller; it is not read.

    Prints a usage line and exits with status 1 when ``FLAGS.file_name`` is
    empty; otherwise calls ``convert(FLAGS.model, FLAGS.file_name)``.
    """
    if not FLAGS.file_name:
        # The script's parser takes `model` and `file_name` as positional
        # arguments, so the usage line shows them that way instead of the
        # old (unsupported) --file_name= form.
        print("Usage: convert.py {v3,v4} checkpoint_file_name")
        sys.exit(1)
    convert(FLAGS.model, FLAGS.file_name)
if __name__ == "__main__":
    # Command-line interface: two positional arguments, model then file_name.
    arg_parser = argparse.ArgumentParser()
    # Registers a "bool" type converter on the parser; neither argument
    # declared below uses it.
    arg_parser.register("type", "bool", lambda v: v.lower() == "true")
    arg_parser.add_argument("model", choices=["v3", "v4"], help="inception version")
    arg_parser.add_argument("file_name", help="Checkpoint path")
    # Recognised options go into the module-level FLAGS that main() reads;
    # everything else is passed through to app.run after the program name.
    FLAGS, leftover_args = arg_parser.parse_known_args()
    forwarded_argv = [sys.argv[0]] + leftover_args
    app.run(main=main, argv=forwarded_argv)