blob: 5518dda4a9891707ca8f5aaba4929a21245a8623 [file] [log] [blame]
from load_model import load_checkpoint
from save_model import save_checkpoint
def _merge_first_wins(primary, secondary):
    """Return a new dict holding every entry of *primary* and *secondary*.

    When a key appears in both mappings, the value from *primary* is kept.
    Neither input is modified.
    """
    merged = dict(primary)
    for name, value in secondary.items():
        # setdefault leaves any entry already taken from *primary* untouched.
        merged.setdefault(name, value)
    return merged


def combine_model(prefix1, epoch1, prefix2, epoch2, prefix_out, epoch_out):
    """Merge the parameters of two checkpoints and save them as a new one.

    Loads ``(prefix1, epoch1)`` and ``(prefix2, epoch2)`` with
    ``load_checkpoint``. Each load yields a pair of mappings: argument
    parameters and auxiliary states. The argument mappings and the auxiliary
    mappings are merged separately. For any name present in both checkpoints,
    the value from the first checkpoint wins. The merged pair is written with
    ``save_checkpoint`` to ``(prefix_out, epoch_out)``.

    Returns None.
    """
    args1, auxs1 = load_checkpoint(prefix1, epoch1)
    args2, auxs2 = load_checkpoint(prefix2, epoch2)
    # Merge with dict operations instead of concatenating ``keys()`` lists.
    # Adding two ``keys()`` results only works on Python 2; on Python 3 they
    # are views and ``+`` raises TypeError. This also avoids visiting shared
    # keys twice.
    args = _merge_first_wins(args1, args2)
    auxs = _merge_first_wins(auxs1, auxs2)
    save_checkpoint(prefix_out, epoch_out, args, auxs)