| import mxnet as mx |
| |
| |
def save_checkpoint(prefix, epoch, arg_params, aux_params):
    """Checkpoint the model parameters into a single ``.params`` file.

    The file is written to ``'%s-%04d.params' % (prefix, epoch)``, i.e. the
    epoch is zero-padded to four digits (``prefix-0007.params``).

    Argument parameters are stored under keys prefixed with ``arg:`` and
    auxiliary states under keys prefixed with ``aux:``, so entries from the
    two dicts cannot collide inside the one saved dict.

    :param prefix: Prefix of the model file name.
    :param epoch: The epoch number of the model; formatted with ``%04d``.
    :param arg_params: dict of str to NDArray
        Model parameters: name -> NDArray of the net's weights.
        ``None`` is treated as an empty dict.
    :param aux_params: dict of str to NDArray
        Model parameters: name -> NDArray of the net's auxiliary states.
        ``None`` is treated as an empty dict.
    :return: None
    """
    # `or {}` lets callers pass None for either dict instead of crashing on
    # `None.items()`; an empty dict contributes no entries.
    save_dict = {'arg:%s' % k: v for k, v in (arg_params or {}).items()}
    save_dict.update({'aux:%s' % k: v for k, v in (aux_params or {}).items()})
    param_name = '%s-%04d.params' % (prefix, epoch)
    mx.nd.save(param_name, save_dict)