| from load_model import load_checkpoint |
| from save_model import save_checkpoint |
| |
| |
def _merge_preferring_first(first, second):
    """Return a new dict holding every entry of *first* and *second*.

    When a key appears in both mappings, the value from *first* is kept.
    Keys of *first* come first, followed by keys present only in *second*.
    Neither input is modified.
    """
    merged = dict(first)
    for name, value in second.items():
        if name not in merged:
            merged[name] = value
    return merged


def combine_model(prefix1, epoch1, prefix2, epoch2, prefix_out, epoch_out):
    """Merge the parameters of two checkpoints and save them as a new one.

    Both checkpoints are read with ``load_checkpoint``, which is expected to
    return an ``(args, auxs)`` pair of dicts (presumably name -> parameter
    array; confirm against ``load_model``). The argument dicts are merged, and
    so are the auxiliary dicts. Names present in both checkpoints take the
    value from the first checkpoint. The merged pair is written with
    ``save_checkpoint``.

    Args:
        prefix1, epoch1: passed to ``load_checkpoint`` for the first
            (higher-priority) checkpoint.
        prefix2, epoch2: passed to ``load_checkpoint`` for the second
            checkpoint.
        prefix_out, epoch_out: passed to ``save_checkpoint`` for the output.
    """
    args1, auxs1 = load_checkpoint(prefix1, epoch1)
    args2, auxs2 = load_checkpoint(prefix2, epoch2)
    # Merge by iterating the dicts rather than concatenating ``keys()``:
    # on Python 3 dict views do not support ``+`` and would raise TypeError.
    args = _merge_preferring_first(args1, args2)
    auxs = _merge_preferring_first(auxs1, auxs2)
    save_checkpoint(prefix_out, epoch_out, args, auxs)