save load for module, loss, metric (#5418)
diff --git a/docs/api/python/module.md b/docs/api/python/module.md
index 5a0198d..1f2553d 100644
--- a/docs/api/python/module.md
+++ b/docs/api/python/module.md
@@ -8,28 +8,47 @@
The module API, defined in the `module` (or simply `mod`) package, provides an
intermediate and high-level interface for performing computation with a
-`Symbol`. One can roughly think a module is a machine which can execute a
-program defined by a `Symbol`.
+`Symbol` or `Loss`. One can roughly think of a module as a machine that can execute a
+program defined by a `Symbol` or `Loss`.
-The class `module.Module` is a commonly used module, which accepts a `Symbol` as
+The class `module.Module` is a commonly used module, which accepts a `Symbol` or `Loss` as
the input:
```python
data = mx.symbol.Variable('data')
+label = mx.sym.Variable('label')
fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=10)
-out = mx.symbol.SoftmaxOutput(fc2, name = 'softmax')
-mod = mx.mod.Module(out) # create a module by given a Symbol
+
+loss = mx.loss.softmax_cross_entropy_loss(fc2, label)
+mod = mx.mod.Module(loss, data_names=('data',))
```
+Alternatively, if you only want to do prediction, or want to compute the loss manually outside
+of the module and feed gradients back using `Module.backward(out_grads=...)`, you can also
+feed a `Symbol` directly into the module:
+
+```python
+data = mx.symbol.Variable('data')
+label = mx.sym.Variable('label')
+fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
+act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
+fc2 = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=10)
+
+mod = mx.mod.Module(fc2, data_names=('data',))
+```
+
+
Assume there is a valid MXNet data iterator `data`. We can initialize the
module:
```python
+# allocate memory for the given input shapes
mod.bind(data_shapes=data.provide_data,
- label_shapes=data.provide_label) # create memory by given input shapes
-mod.init_params() # initial parameters with the default random initializer
+ label_shapes=data.provide_label)
+# initialize parameters from a uniform distribution in [-0.05, 0.05]
+mod.init_params(mx.init.Uniform(0.05))
```
Now the module is able to compute. We can call high-level API to train and
@@ -43,9 +62,10 @@
or use intermediate APIs to perform step-by-step computations
```python
-mod.forward(data_batch) # forward on the provided data batch
-mod.backward() # backward to calculate the gradients
-mod.update() # update parameters using the default optimizer
+for data_batch in data:
+ mod.forward(data_batch) # forward on the provided data batch
+ mod.backward() # backward to calculate the gradients
+ mod.update() # update parameters using the default optimizer
```
A detailed tutorial is available at [http://mxnet.io/tutorials/python/module.html](http://mxnet.io/tutorials/python/module.html).
diff --git a/python/mxnet/loss.py b/python/mxnet/loss.py
index 8047761..d5ec6ad 100644
--- a/python/mxnet/loss.py
+++ b/python/mxnet/loss.py
@@ -3,8 +3,9 @@
""" losses for training neural networks """
from __future__ import absolute_import
-from .base import numeric_types, string_types
+from .base import numeric_types, string_types, __version__
from . import symbol
+from . import registry
from . import metric as _metric
@@ -58,16 +59,16 @@
return metric
-class Loss(object):
+class BaseLoss(object):
"""Base class for all loss layers.
Parameters
----------
- loss : Symbol
+ losses : Symbol
a symbol whose output is the loss. Can be a scalar value
or an array. If loss is an array, the sum of its elements
will be the final loss.
- output : Symbol
+ outputs : Symbol
output of the model when predicting.
label_names : list of str
names of label variables. labels are used for training
@@ -84,22 +85,47 @@
Returns
-------
- loss : Loss
+ loss : BaseLoss
created loss
"""
- def __init__(self, loss, output, label_names, name, metric=None,
- output_head_grad=False, loss_head_grad=False):
+ def __init__(self, losses, outputs, label_names, name, metric=None,
+ output_head_grad=False, loss_head_grad=False, **kwargs):
+ if losses is None:
+ sym = list(symbol.load_json(kwargs['__group_sym__']))
+ num = kwargs['__num_losses__']
+ losses = symbol.Group(sym[:num])
+ outputs = symbol.Group(sym[num:])
+ if metric is not None:
+ metric = _metric.create(**metric)
+
+ losses = symbol.Group(list(losses))
+ outputs = symbol.Group(list(outputs))
+ self._kwargs = kwargs
+ self._kwargs.update({
+ 'loss': self.__class__.__name__,
+ 'losses': None,
+ 'outputs': None,
+ 'label_names': label_names,
+ 'name': name,
+ 'metric': metric.get_config() if metric is not None else None,
+ 'output_head_grad': output_head_grad,
+ 'loss_head_grad': loss_head_grad,
+ '__group_sym__': symbol.Group(list(losses)+list(outputs)).tojson(),
+ '__num_losses__': len(list(losses)),
+ '__type__': 'loss',
+ '__version__': __version__})
+
if not loss_head_grad:
self._loss_symbol = symbol.Group([symbol.make_loss(x, name=x.name+'_loss')
- for x in loss])
+ for x in losses])
else:
- self._loss_symbol = loss
+ self._loss_symbol = losses
if not output_head_grad:
self._output_symbol = symbol.Group([symbol.stop_gradient(x, name=x.name+'_out')
- for x in output])
+ for x in outputs])
else:
- self._output_symbol = output
+ self._output_symbol = outputs
self._label_names = list(label_names) if label_names else []
self._name = name
@@ -136,6 +162,17 @@
"""Metric for evaluation"""
return self._metric
+ def get_config(self):
+ """get configs for serialization"""
+ return self._kwargs.copy()
+
+
+# pylint: disable=invalid-name
+register = registry.get_register_func(BaseLoss, 'loss')
+create = registry.get_create_func(BaseLoss, 'loss')
+register(BaseLoss)
+# pylint: enable=invalid-name
+
def custom_loss(loss, output, label_names, extra_outputs=(),
weight=None, sample_weight=None, name='loss',
@@ -182,7 +219,7 @@
Returns
-------
- loss : Loss
+ loss : BaseLoss
created loss
"""
label_names = list(label_names)
@@ -193,7 +230,7 @@
if i not in loss.list_arguments()]
loss = _apply_weight(loss, weight=weight, sample_weight=sample_weight)
loss._set_attr(name=name)
- return Loss(loss, output, label_names, name, **kwargs)
+ return BaseLoss(loss, output, label_names, name, **kwargs)
def multi_loss(losses, extra_outputs=(), name='multi'):
@@ -202,7 +239,7 @@
Parameters
----------
- losses : list of Loss
+ losses : list of BaseLoss
a list of individual losses with no extra outputs.
extra_outputs : list of Symbol
extra outputs for predition but not used for evaluating
@@ -213,7 +250,7 @@
Returns
-------
- loss : Loss
+ loss : BaseLoss
created loss
"""
loss = symbol.Group(sum([list(i.loss_symbol) for i in losses], []))
@@ -223,8 +260,8 @@
for name in i.label_names:
if name not in label_names:
label_names.append(name)
- ret = Loss(loss, output, label_names, name,
- output_head_grad=True, loss_head_grad=True)
+ ret = BaseLoss(loss, output, label_names, name,
+ output_head_grad=True, loss_head_grad=True)
del ret.metric.metrics[:]
for i in losses:
ret.metric.add(i.metric)
@@ -272,7 +309,7 @@
Returns
-------
- loss : Loss
+ loss : BaseLoss
created loss
"""
metric = _parse_metric(metric, output, label)
@@ -286,7 +323,7 @@
label_names = [x for x in loss.list_arguments()
if x not in output.list_arguments()]
- return Loss(loss, outputs, label_names, name, metric=metric, **kwargs)
+ return BaseLoss(loss, outputs, label_names, name, metric=metric, **kwargs)
def l1_loss(output, label, extra_outputs=(), name='l1',
@@ -329,7 +366,7 @@
Returns
-------
- loss : Loss
+ loss : BaseLoss
created loss
"""
metric = _parse_metric(metric, output, label)
@@ -343,12 +380,12 @@
label_names = [x for x in loss.list_arguments()
if x not in output.list_arguments()]
- return Loss(loss, outputs, label_names, name, metric=metric, **kwargs)
+ return BaseLoss(loss, outputs, label_names, name, metric=metric, **kwargs)
-def cross_entropy_loss(output, label, sparse_label=True, axis=1,
- extra_outputs=(), name='ce', weight=None,
- sample_weight=None, metric='acc', **kwargs):
+def softmax_cross_entropy_loss(output, label, sparse_label=True, axis=1,
+ extra_outputs=(), name='ce', weight=None,
+ sample_weight=None, metric='acc', **kwargs):
"""Compute the softmax cross entropy loss.
If sparse_label is True, label should contain integer category indicators:
@@ -395,7 +432,7 @@
Returns
-------
- loss : Loss
+ loss : BaseLoss
created loss
"""
metric = _parse_metric(metric, output, label)
@@ -418,4 +455,4 @@
label_names = [x for x in loss.list_arguments()
if x not in output.list_arguments()]
- return Loss(loss, outputs, label_names, name, metric=metric, **kwargs)
+ return BaseLoss(loss, outputs, label_names, name, metric=metric, **kwargs)
diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index 3bd26b7..6b7f87a 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -7,9 +7,12 @@
from collections import OrderedDict
import numpy
+import pickle
-from .base import numeric_types
+from . import base
+from .base import numeric_types, string_types
from . import ndarray
+from . import registry
def check_label_shapes(labels, preds, shape=0):
@@ -24,25 +27,36 @@
raise ValueError("Shape of labels {} does not match shape of "
"predictions {}".format(label_shape, pred_shape))
+
class EvalMetric(object):
"""Base class of all evaluation metrics."""
- def __init__(self, name, num=None, output_names=None, label_names=None):
+ def __init__(self, name, num=None, output_names=None,
+ label_names=None, **kwargs):
self.name = name
self.num = num
self.output_names = output_names
self.label_names = label_names
+ self._kwargs = kwargs
self.reset()
def __str__(self):
return "EvalMetric: {}".format(dict(self.get_name_value()))
- def __getstate__(self):
- return self.__dict__.copy()
-
- def __setstate__(self, state):
- self.__dict__.update(state)
- self.reset()
+ def get_config(self):
+ """Save configurations of metric. Can be recreated
+ from configs with metric.create(**config)
+ """
+ config = self._kwargs.copy()
+ config.update({
+ 'metric': self.__class__.__name__,
+ 'name': self.name,
+ 'num': self.num,
+ 'output_names': self.output_names,
+ 'label_names': self.label_names,
+ '__type__': 'metric',
+ '__version__': base.__version__})
+ return config
def update_dict(self, label, pred):
"""Update the internal evaluation with named label and pred
@@ -119,7 +133,44 @@
value = [value]
return list(zip(name, value))
+# pylint: disable=invalid-name
+register = registry.get_register_func(EvalMetric, 'metric')
+alias = registry.get_alias_func(EvalMetric, 'metric')
+_create = registry.get_create_func(EvalMetric, 'metric')
+# pylint: enable=invalid-name
+
+def create(metric, *args, **kwargs):
+ """Create an evaluation metric.
+
+ Parameters
+ ----------
+ metric : str or callable
+ The name of the metric, or a function
+ providing statistics given pred, label NDArray.
+ *args : list
+ additional arguments to metric constructor
+ **kwargs : dict
+ additional arguments to metric constructor
+
+ Returns
+ -------
+ created metric
+ """
+
+ if callable(metric):
+ return CustomMetric(metric, *args, **kwargs)
+ elif isinstance(metric, list):
+ composite_metric = CompositeEvalMetric()
+ for child_metric in metric:
+ composite_metric.add(create(child_metric, *args, **kwargs))
+ return composite_metric
+
+ return _create(metric, *args, **kwargs)
+
+
+@register
+@alias('composite')
class CompositeEvalMetric(EvalMetric):
"""Manage multiple evaluation metrics."""
@@ -176,10 +227,19 @@
values.extend(value)
return (names, values)
+ def get_config(self):
+ config = super(CompositeEvalMetric, self).get_config()
+ config.update({'metrics': [i.get_config() for i in self.metrics]})
+ return config
+
+
########################
# CLASSIFICATION METRICS
########################
+
+@register
+@alias('acc')
class Accuracy(EvalMetric):
"""Calculate accuracy
@@ -188,8 +248,8 @@
axis : int, default=1
The axis that represents classes
"""
- def __init__(self, axis=1, **kwargs):
- super(Accuracy, self).__init__('accuracy', **kwargs)
+ def __init__(self, axis=1, name='accuracy', **kwargs):
+ super(Accuracy, self).__init__(name, axis=axis, **kwargs)
self.axis = axis
def update(self, labels, preds):
@@ -206,11 +266,14 @@
self.sum_metric += (pred_label.flat == label.flat).sum()
self.num_inst += len(pred_label.flat)
+
+@register
+@alias('top_k_accuracy', 'top_k_acc')
class TopKAccuracy(EvalMetric):
"""Calculate top k predictions accuracy"""
- def __init__(self, top_k=1, **kwargs):
- super(TopKAccuracy, self).__init__('top_k_accuracy', **kwargs)
+ def __init__(self, top_k=1, name='top_k_accuracy', **kwargs):
+ super(TopKAccuracy, self).__init__(name, top_k=top_k, **kwargs)
self.top_k = top_k
assert(self.top_k > 1), 'Please use Accuracy if top_k is no more than 1'
self.name += '_%d' % self.top_k
@@ -234,11 +297,13 @@
self.sum_metric += (pred_label[:, num_classes - 1 - j].flat == label.flat).sum()
self.num_inst += num_samples
+
+@register
class F1(EvalMetric):
"""Calculate the F1 score of a binary classification problem."""
- def __init__(self, **kwargs):
- super(F1, self).__init__('f1', **kwargs)
+ def __init__(self, name='f1', **kwargs):
+ super(F1, self).__init__(name, **kwargs)
def update(self, labels, preds):
check_label_shapes(labels, preds)
@@ -281,6 +346,7 @@
self.num_inst += 1
+@register
class Perplexity(EvalMetric):
"""Calculate perplexity
@@ -291,8 +357,8 @@
counting. usually should be -1. Include
all entries if None.
"""
- def __init__(self, ignore_label, **kwargs):
- super(Perplexity, self).__init__('Perplexity', **kwargs)
+ def __init__(self, ignore_label, name='perplexity', **kwargs):
+ super(Perplexity, self).__init__(name, ignore_label=ignore_label, **kwargs)
self.ignore_label = ignore_label
def update(self, labels, preds):
@@ -326,11 +392,13 @@
# REGRESSION METRICS
####################
+
+@register
class MAE(EvalMetric):
"""Calculate Mean Absolute Error loss"""
- def __init__(self, **kwargs):
- super(MAE, self).__init__('mae', **kwargs)
+ def __init__(self, name='mae', **kwargs):
+ super(MAE, self).__init__(name, **kwargs)
def update(self, labels, preds):
check_label_shapes(labels, preds)
@@ -345,10 +413,12 @@
self.sum_metric += numpy.abs(label - pred).mean()
self.num_inst += 1 # numpy.prod(label.shape)
+
+@register
class MSE(EvalMetric):
"""Calculate Mean Squared Error loss"""
- def __init__(self, **kwargs):
- super(MSE, self).__init__('mse', **kwargs)
+ def __init__(self, name='mse', **kwargs):
+ super(MSE, self).__init__(name, **kwargs)
def update(self, labels, preds):
check_label_shapes(labels, preds)
@@ -363,10 +433,12 @@
self.sum_metric += ((label - pred)**2.0).mean()
self.num_inst += 1 # numpy.prod(label.shape)
+
+@register
class RMSE(EvalMetric):
"""Calculate Root Mean Squred Error loss"""
- def __init__(self, **kwargs):
- super(RMSE, self).__init__('rmse', **kwargs)
+ def __init__(self, name='rmse', **kwargs):
+ super(RMSE, self).__init__(name, **kwargs)
def update(self, labels, preds):
check_label_shapes(labels, preds)
@@ -381,10 +453,13 @@
self.sum_metric += numpy.sqrt(((label - pred)**2.0).mean())
self.num_inst += 1
+
+@register
+@alias('ce')
class CrossEntropy(EvalMetric):
"""Calculate Cross Entropy loss"""
- def __init__(self, eps=1e-8, **kwargs):
- super(CrossEntropy, self).__init__('cross-entropy', **kwargs)
+ def __init__(self, eps=1e-8, name='cross-entropy', **kwargs):
+ super(CrossEntropy, self).__init__(name, eps=eps, **kwargs)
self.eps = eps
def update(self, labels, preds):
@@ -402,6 +477,7 @@
self.num_inst += label.shape[0]
+@register
class Loss(EvalMetric):
"""Dummy metric for directly printing loss"""
def __init__(self, name='loss', **kwargs):
@@ -413,18 +489,21 @@
self.num_inst += pred.size
+@register
class Torch(Loss):
"""Dummy metric for torch criterions"""
def __init__(self, name='torch', **kwargs):
super(Torch, self).__init__(name, **kwargs)
+@register
class Caffe(Loss):
"""Dummy metric for caffe criterions"""
def __init__(self, name='caffe', **kwargs):
super(Caffe, self).__init__(name, **kwargs)
+@register
class CustomMetric(EvalMetric):
"""Custom evaluation metric that takes a NDArray function.
@@ -440,11 +519,16 @@
in outputs for forwarding.
"""
def __init__(self, feval, name=None, allow_extra_outputs=False, **kwargs):
+ if isinstance(feval, string_types):
+ feval = pickle.loads(feval)
if name is None:
name = feval.__name__
if name.find('<') != -1:
name = 'custom(%s)' % name
- super(CustomMetric, self).__init__(name, **kwargs)
+ super(CustomMetric, self).__init__(
+ name, feval=pickle.dumps(feval),
+ allow_extra_outputs=allow_extra_outputs,
+ **kwargs)
self._feval = feval
self._allow_extra_outputs = allow_extra_outputs
@@ -465,6 +549,7 @@
self.sum_metric += reval
self.num_inst += 1
+
# pylint: disable=invalid-name
def np(numpy_feval, name=None, allow_extra_outputs=False):
"""Create a customized metric from numpy function.
@@ -489,40 +574,3 @@
feval.__name__ = numpy_feval.__name__
return CustomMetric(feval, name, allow_extra_outputs)
# pylint: enable=invalid-name
-
-def create(metric, **kwargs):
- """Create an evaluation metric.
-
- Parameters
- ----------
- metric : str or callable
- The name of the metric, or a function
- providing statistics given pred, label NDArray.
- """
-
- if callable(metric):
- return CustomMetric(metric)
- elif isinstance(metric, EvalMetric):
- return metric
- elif isinstance(metric, list):
- composite_metric = CompositeEvalMetric()
- for child_metric in metric:
- composite_metric.add(create(child_metric, **kwargs))
- return composite_metric
-
- metrics = {
- 'acc': Accuracy,
- 'accuracy': Accuracy,
- 'ce': CrossEntropy,
- 'f1': F1,
- 'mae': MAE,
- 'mse': MSE,
- 'rmse': RMSE,
- 'top_k_accuracy': TopKAccuracy
- }
-
- try:
- return metrics[metric.lower()](**kwargs)
- except:
- raise ValueError("Metric must be either callable or in {}".format(
- metrics.keys()))
diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index 2db7ce6..4af8c76 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -3,6 +3,7 @@
"""MXNet model module"""
from __future__ import absolute_import, print_function
+import os
import time
import logging
import warnings
@@ -370,7 +371,10 @@
- symbol will be loaded from ``prefix-symbol.json``.
- parameters will be loaded from ``prefix-epoch.params``.
"""
- symbol = sym.load('%s-symbol.json' % prefix)
+ if os.path.exists('%s-symbol.json' % prefix):
+ symbol = sym.load('%s-symbol.json' % prefix)
+ else:
+ symbol = None
save_dict = nd.load('%s-%04d.params' % (prefix, epoch))
arg_params = {}
aux_params = {}
diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py
index aaf1fe0..238b053 100644
--- a/python/mxnet/module/base_module.py
+++ b/python/mxnet/module/base_module.py
@@ -160,20 +160,26 @@
- `predict`: run prediction on a data set and collect outputs
- `score`: run prediction on a data set and evaluate performance
- Examples
- --------
- An example of creating a mxnet module::
- >>> import mxnet as mx
+ To create a module for training classification::
+ data = mx.sym.Variable('data')
+ output = mx.sym.FullyConnected(data, num_hidden=10)
+ label = mx.sym.Variable('label')
+ loss = mx.loss.softmax_cross_entropy_loss(output, label)
+ model = mx.mod.Module(loss, data_names=('data',))
+ model.fit(..., eval_metric=loss.metric)
+ model.score(..., eval_metric=loss.metric)
- >>> data = mx.symbol.Variable('data')
- >>> fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
- >>> act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
- >>> fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64)
- >>> act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
- >>> fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
- >>> out = mx.symbol.SoftmaxOutput(fc3, name = 'softmax')
+ To create a module for prediction only::
+ data = mx.sym.Variable('data')
+ output = mx.sym.FullyConnected(data, num_hidden=10)
+ model = mx.mod.Module(output, data_names=('data',))
+ model.bind(data_shapes=[('data', (128, 100))], label_shapes=None)
+ model.load_params('save-0001.params')
+ model.predict(...)
- >>> mod = mx.mod.Module(out)
+ You can also load from saved checkpoints::
+ model.save_checkpoint('save', 1)
+ model2 = mx.mod.Module.load('save', 1, context=mx.cpu(0))
"""
def __init__(self, logger=logging):
self.logger = logger
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index de54c16..26e0c8b 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -4,6 +4,8 @@
more `Executor` for data parallelization.
"""
+import os
+import json
import logging
import warnings
@@ -12,7 +14,7 @@
from .. import symbol as _sym
from .. import optimizer as opt
from .. import loss
-from ..base import _Sentinel
+from ..base import _Sentinel, __version__, string_types
from .executor_group import DataParallelExecutorGroup
from ..model import _create_kvstore, _initialize_kvstore, _update_params, _update_params_on_kvstore
@@ -23,8 +25,8 @@
class Module(BaseModule):
- """Module is a basic module that wrap a `Symbol`. It is functionally the same
- as the `FeedForward` model, except under the module API.
+ """Module is a basic module that wrap a `Loss` or `Symbol`.
+ It is functionally the same as the `FeedForward` model.
Parameters
----------
@@ -48,9 +50,14 @@
"""
def __init__(self, symbol, data_names=('data',), label_names=_Sentinel,
logger=logging, context=ctx.cpu(), work_load_list=None,
- fixed_param_names=None, state_names=None):
+ fixed_param_names=None, state_names=None, **kwargs):
super(Module, self).__init__(logger=logger)
+ if isinstance(symbol, string_types):
+ symbol = _sym.load_json(symbol)
+ elif isinstance(symbol, dict):
+ symbol = loss.create(**symbol)
+
if isinstance(context, ctx.Context):
context = [context]
self._context = context
@@ -59,7 +66,19 @@
assert len(work_load_list) == len(self._context)
self._work_load_list = work_load_list
- if isinstance(symbol, loss.Loss):
+ self._kwargs = kwargs
+ if isinstance(symbol, loss.BaseLoss):
+ self._kwargs['symbol'] = symbol.get_config()
+ else:
+ self._kwargs['symbol'] = symbol.tojson()
+ self._kwargs.update({
+ 'data_names': data_names,
+ 'fixed_param_names': fixed_param_names,
+ 'state_names': state_names,
+ '__type__': 'module',
+ '__version__': __version__})
+
+ if isinstance(symbol, loss.BaseLoss):
self._loss = symbol
self._symbol = _sym.Group([self._loss.output_symbol, self._loss.loss_symbol])
num_output = len(self._loss.output_symbol.list_outputs())
@@ -67,7 +86,7 @@
self._output_range = (0, num_output)
self._loss_range = (num_output, num_output+num_loss)
assert label_names is _Sentinel, \
- "label_names has been deprecated. do not use"
+ "label_names has been deprecated. Do not set."
label_names = self._loss.label_names
else:
self._symbol = symbol
@@ -120,34 +139,38 @@
def load(prefix, epoch, load_optimizer_states=False, **kwargs):
"""Create a model from previously saved checkpoint.
+ For example, use::
+ mod = mx.mod.Module.load('test', 100, context=mx.gpu(0))
+
+ to load from "test-module.json" and "test-0100.params"
+
Parameters
----------
prefix : str
path prefix of saved model files. You should have
- "prefix-symbol.json", "prefix-xxxx.params", and
- optionally "prefix-xxxx.states", where xxxx is the
- epoch number.
+ "prefix-symbol.json"/"prefix-module.json",
+ "prefix-xxxx.params", and optionally "prefix-xxxx.states",
+ where xxxx is the epoch number.
epoch : int
epoch to load.
load_optimizer_states : bool
whether to load optimizer states. Checkpoint needs
to have been made with save_optimizer_states=True.
- data_names : list of str
- Default is `('data')` for a typical model used in image classification.
- label_names : list of str
- Default is `('softmax_label')` for a typical model used in image
- classification.
- logger : Logger
- Default is `logging`.
context : Context or list of Context
Default is `cpu()`.
work_load_list : list of number
Default `None`, indicating uniform workload.
- fixed_param_names: list of str
- Default `None`, indicating no network parameters are fixed.
+ logger : Logger
+ Default is `logging`.
"""
sym, args, auxs = load_checkpoint(prefix, epoch)
- mod = Module(symbol=sym, **kwargs)
+ if os.path.exists('%s-module.json'%prefix):
+ config = json.loads(open('%s-module.json'%prefix).read())
+ config.update(kwargs)
+ mod = Module(**config)
+ else:
+ mod = Module(sym, **kwargs)
+
mod._arg_params = args
mod._aux_params = auxs
mod.params_initialized = True
@@ -157,7 +180,11 @@
def save_checkpoint(self, prefix, epoch, save_optimizer_states=False):
"""Save current progress to checkpoint.
- Use mx.callback.module_checkpoint as epoch_end_callback to save during training.
+ Use mx.callback.module_checkpoint as
+ epoch_end_callback to save during training.
+
+ Outputs 'prefix-module.json', 'prefix-symbol.json',
+ 'prefix-(epoch).params', and optionally 'prefix-(epoch).states'.
Parameters
----------
@@ -168,7 +195,13 @@
save_optimizer_states : bool
Whether to save optimizer states for continue training
"""
- self._symbol.save('%s-symbol.json'%prefix)
+ if self._loss is not None:
+ self._loss.output_symbol.save('%s-symbol.json'%prefix)
+ else:
+ self._symbol.save('%s-symbol.json'%prefix)
+ with open('%s-module.json'%prefix, 'w') as fout:
+ json.dump(self._kwargs, fout)
+
param_name = '%s-%04d.params' % (prefix, epoch)
self.save_params(param_name)
logging.info('Saved checkpoint to \"%s\"', param_name)
diff --git a/python/mxnet/registry.py b/python/mxnet/registry.py
new file mode 100644
index 0000000..4d1cf0f
--- /dev/null
+++ b/python/mxnet/registry.py
@@ -0,0 +1,127 @@
+# coding: utf-8
+# pylint: disable=no-member
+
+"""Registry for serializable objects."""
+from __future__ import absolute_import
+
+import warnings
+
+from .base import string_types
+
+_REGISTRY = {}
+
+
+def get_register_func(base_class, nickname):
+ """Get registrator function.
+
+ Parameters
+ ----------
+ base_class : type
+ base class for classes that will be registered
+ nickname : str
+ nickname of base_class for logging
+
+ Returns
+ -------
+ a registrator function
+ """
+ if base_class not in _REGISTRY:
+ _REGISTRY[base_class] = {}
+ registry = _REGISTRY[base_class]
+
+ def register(klass, name=None):
+ """Register functions"""
+ assert issubclass(klass, base_class), \
+ "Can only register subclass of %s"%base_class.__name__
+ if name is None:
+ name = klass.__name__.lower()
+ if name in registry:
+ warnings.warn(
+ "\033[91mNew %s %s.%s registered with name %s is"
+ "overriding existing %s %s.%s\033[0m"%(
+ nickname, klass.__module__, klass.__name__, name,
+ nickname, registry[name].__module__, registry[name].__name__),
+ UserWarning, stacklevel=2)
+ registry[name] = klass
+ return klass
+
+ register.__doc__ = "Register %s to the %s factory"%(nickname, nickname)
+ return register
+
+
+def get_alias_func(base_class, nickname):
+ """Get registrator function that allow aliases.
+
+ Parameters
+ ----------
+ base_class : type
+ base class for classes that will be registered
+ nickname : str
+ nickname of base_class for logging
+
+ Returns
+ -------
+ a registrator function
+ """
+ register = get_register_func(base_class, nickname)
+
+ def alias(*aliases):
+ """alias registrator"""
+ def reg(klass):
+ """registrator function"""
+ for name in aliases:
+ register(klass, name)
+ return klass
+ return reg
+ return alias
+
+
+def get_create_func(base_class, nickname):
+ """Get creator function
+
+ Parameters
+ ----------
+ base_class : type
+ base class for classes that will be registered
+ nickname : str
+ nickname of base_class for logging
+
+ Returns
+ -------
+ a creator function
+ """
+ if base_class not in _REGISTRY:
+ _REGISTRY[base_class] = {}
+ registry = _REGISTRY[base_class]
+
+ def create(*args, **kwargs):
+ """Create instance from config"""
+ if len(args):
+ name = args[0]
+ args = args[1:]
+ else:
+ name = kwargs.pop(nickname)
+
+ if isinstance(name, base_class):
+ assert len(args) == 0 and len(kwargs) == 0, \
+ "%s is already an instance. Additional arguments are invalid"%(nickname)
+ return name
+
+ assert isinstance(name, string_types), "%s must be of string type"%nickname
+ name = name.lower()
+ assert name in registry, \
+ "%s is not registered. Please register with %s.register first"%(
+ str(name), nickname)
+ return registry[name](*args, **kwargs)
+
+ create.__doc__ = """Create a %s instance from config.
+
+Parameters
+----------
+%s : str or %s instance
+ class name of desired instance. If is a instance,
+ it will be returned directly.
+**kwargs : dict
+ arguments to be passed to constructor"""%(nickname, nickname, base_class.__name__)
+
+ return create
diff --git a/tests/python/unittest/test_loss.py b/tests/python/unittest/test_loss.py
index 714a656..4ec17b7 100644
--- a/tests/python/unittest/test_loss.py
+++ b/tests/python/unittest/test_loss.py
@@ -21,7 +21,7 @@
data_iter = mx.io.NDArrayIter(data, label, batch_size=10, label_name='label')
output = get_net(nclass)
l = mx.symbol.Variable('label')
- loss = mx.loss.cross_entropy_loss(output, l)
+ loss = mx.loss.softmax_cross_entropy_loss(output, l)
mod = mx.mod.Module(loss)
mod.fit(data_iter, eval_metric=loss.metric, num_epoch=200, optimizer_params={'learning_rate': 1.})
assert mod.score(data_iter, loss.metric)[0][1] == 1.0
@@ -87,7 +87,7 @@
output = get_net(nclass)
l = mx.symbol.Variable('label')
w = mx.symbol.Variable('w')
- loss = mx.loss.cross_entropy_loss(output, l, sample_weight=w)
+ loss = mx.loss.softmax_cross_entropy_loss(output, l, sample_weight=w)
mod = mx.mod.Module(loss)
mod.fit(data_iter, eval_metric=loss.metric, num_epoch=200,
optimizer_params={'learning_rate': 1.})
@@ -111,7 +111,7 @@
output2 = mx.symbol.FullyConnected(act3, name='output2', num_hidden=1)
l1 = mx.symbol.Variable('label1')
l2 = mx.symbol.Variable('label2')
- loss1 = mx.loss.cross_entropy_loss(output1, l1)
+ loss1 = mx.loss.softmax_cross_entropy_loss(output1, l1)
loss2 = mx.loss.l2_loss(output2, l2)
loss = mx.loss.multi_loss([loss1, loss2])
mod = mx.mod.Module(loss)
@@ -124,7 +124,36 @@
assert score[2][1] < 0.2
+def test_saveload():
+ nclass = 10
+ output = get_net(nclass)
+ l = mx.symbol.Variable('label')
+ loss = mx.loss.softmax_cross_entropy_loss(output, l)
+ assert mx.loss.create(**loss.get_config()).get_config() == loss.get_config()
+
+
+def test_saveload2():
+ mx.random.seed(1234)
+ np.random.seed(1234)
+ nclass = 10
+ N = 20
+ data = mx.random.uniform(-1, 1, shape=(N, nclass))
+ label = mx.nd.array(np.random.randint(0, nclass, size=(N,)), dtype='int32')
+ data_iter = mx.io.NDArrayIter(data, label, batch_size=10, label_name='label')
+ output = get_net(nclass)
+ l = mx.symbol.Variable('label')
+ loss = mx.loss.softmax_cross_entropy_loss(output, l)
+ mod = mx.mod.Module(loss)
+ mod.fit(data_iter, eval_metric=loss.metric, num_epoch=100, optimizer_params={'learning_rate': 1.})
+ mod.save_checkpoint('test', 100, save_optimizer_states=True)
+ mod = mx.mod.Module.load('test', 100, load_optimizer_states=True)
+ mod.fit(data_iter, eval_metric=mod._loss.metric, num_epoch=100, optimizer_params={'learning_rate': 1.})
+ assert mod.score(data_iter, loss.metric)[0][1] == 1.0
+
+
if __name__ == '__main__':
+ test_saveload()
+ test_saveload2()
test_multi_loss()
test_sample_weight_loss()
test_custom_loss()