# blob: 1c9886973bd6d5f58f196d08356c24e56fee6fcd [file] [log] [blame]
import mxnet as mx
def save_checkpoint(prefix, epoch, arg_params, aux_params):
    """Write the model parameters of one epoch to a single ``.params`` file.

    The output file is named ``<prefix>-<epoch>.params``, with the epoch
    zero-padded to at least four digits.

    :param prefix: Prefix of the model name; becomes the start of the file name.
    :param epoch: The epoch number of the model; becomes part of the file name.
    :param arg_params: dict of str to NDArray
        Net weights, keyed by parameter name.
    :param aux_params: dict of str to NDArray
        Net auxiliary states, keyed by name.
    :return: None
    """
    # Both groups go into one flat dict; the 'arg:' / 'aux:' key prefix
    # records which group each entry came from.
    tagged = {}
    for name, value in arg_params.items():
        tagged['arg:%s' % name] = value
    for name, value in aux_params.items():
        tagged['aux:%s' % name] = value

    mx.nd.save('%s-%04d.params' % (prefix, epoch), tagged)