save load for module, loss, metric (#5418)

diff --git a/docs/api/python/module.md b/docs/api/python/module.md
index 5a0198d..1f2553d 100644
--- a/docs/api/python/module.md
+++ b/docs/api/python/module.md
@@ -8,28 +8,47 @@
 
 The module API, defined in the `module` (or simply `mod`) package, provides an
 intermediate and high-level interface for performing computation with a
-`Symbol`. One can roughly think a module is a machine which can execute a
-program defined by a `Symbol`.
+`Symbol` or `Loss`. One can roughly think of a module as a machine that can execute a
+program defined by a `Symbol` or `Loss`.
 
-The class `module.Module` is a commonly used module, which accepts a `Symbol` as
+The class `module.Module` is a commonly used module, which accepts a `Symbol` or `Loss` as
 the input:
 
 ```python
 data = mx.symbol.Variable('data')
+label = mx.sym.Variable('label')
 fc1  = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
 act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
 fc2  = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=10)
-out  = mx.symbol.SoftmaxOutput(fc2, name = 'softmax')
-mod = mx.mod.Module(out)  # create a module by given a Symbol
+
+loss = mx.loss.softmax_cross_entropy_loss(fc2, label)
+mod = mx.mod.Module(loss, data_names=('data',))
 ```
 
+Alternatively, if you only want to run prediction, or want to compute the loss manually
+outside of the module and feed the gradient back with `Module.backward(out_grads=...)`,
+you can also feed a `Symbol` directly into the module:
+
+```python
+data = mx.symbol.Variable('data')
+label = mx.sym.Variable('label')
+fc1  = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
+act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
+fc2  = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=10)
+
+mod = mx.mod.Module(fc2, data_names=('data',))
+```
+
+
 Assume there is a valid MXNet data iterator `data`. We can initialize the
 module:
 
 ```python
+# allocate memory according to the given input shapes
 mod.bind(data_shapes=data.provide_data,
-         label_shapes=data.provide_label)  # create memory by given input shapes
-mod.init_params()  # initial parameters with the default random initializer
+         label_shapes=data.provide_label)
+# initialize parameters with a uniform distribution in [-0.05, 0.05]
+mod.init_params(mx.init.Uniform(0.05))
 ```
 
 Now the module is able to compute. We can call high-level API to train and
@@ -43,9 +62,10 @@
 or use intermediate APIs to perform step-by-step computations
 
 ```python
-mod.forward(data_batch)  # forward on the provided data batch
-mod.backward()  # backward to calculate the gradients
-mod.update()  # update parameters using the default optimizer
+for data_batch in data:
+    mod.forward(data_batch)  # forward on the provided data batch
+    mod.backward()  # backward to calculate the gradients
+    mod.update()  # update parameters using the default optimizer
 ```
 
 A detailed tutorial is available at [http://mxnet.io/tutorials/python/module.html](http://mxnet.io/tutorials/python/module.html).
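
Taken together with the checkpointing changes below, the intended end-to-end workflow looks roughly like this. A minimal sketch, assuming a training iterator `train_iter` already exists (the prefix `'ckpt'` is a placeholder):

```python
import mxnet as mx

data = mx.sym.Variable('data')
label = mx.sym.Variable('label')
fc = mx.sym.FullyConnected(data, num_hidden=10)
loss = mx.loss.softmax_cross_entropy_loss(fc, label)

mod = mx.mod.Module(loss, data_names=('data',))
mod.fit(train_iter, eval_metric=loss.metric, num_epoch=10)

# writes 'ckpt-module.json', 'ckpt-symbol.json', and 'ckpt-0010.params'
mod.save_checkpoint('ckpt', 10)

# rebuilds the module, including its Loss and metric, from the saved config
mod2 = mx.mod.Module.load('ckpt', 10, context=mx.cpu(0))
```
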
diff --git a/python/mxnet/loss.py b/python/mxnet/loss.py
index 8047761..d5ec6ad 100644
--- a/python/mxnet/loss.py
+++ b/python/mxnet/loss.py
@@ -3,8 +3,9 @@
 """ losses for training neural networks """
 from __future__ import absolute_import
 
-from .base import numeric_types, string_types
+from .base import numeric_types, string_types, __version__
 from . import symbol
+from . import registry
 from . import metric as _metric
 
 
@@ -58,16 +59,16 @@
     return metric
 
 
-class Loss(object):
+class BaseLoss(object):
     """Base class for all loss layers.
 
     Parameters
     ----------
-    loss : Symbol
+    losses : Symbol
         a symbol whose output is the loss. Can be a scalar value
         or an array. If it is an array, the sum of its elements
         will be the final loss.
-    output : Symbol
+    outputs : Symbol
         output of the model when predicting.
     label_names : list of str
         names of label variables. labels are used for training
@@ -84,22 +85,47 @@
 
     Returns
     -------
-    loss : Loss
+    loss : BaseLoss
         created loss
     """
-    def __init__(self, loss, output, label_names, name, metric=None,
-                 output_head_grad=False, loss_head_grad=False):
+    def __init__(self, losses, outputs, label_names, name, metric=None,
+                 output_head_grad=False, loss_head_grad=False, **kwargs):
+        if losses is None:
+            sym = list(symbol.load_json(kwargs['__group_sym__']))
+            num = kwargs['__num_losses__']
+            losses = symbol.Group(sym[:num])
+            outputs = symbol.Group(sym[num:])
+            if metric is not None:
+                metric = _metric.create(**metric)
+
+        losses = symbol.Group(list(losses))
+        outputs = symbol.Group(list(outputs))
+        self._kwargs = kwargs
+        self._kwargs.update({
+            'loss': self.__class__.__name__,
+            'losses': None,
+            'outputs': None,
+            'label_names': label_names,
+            'name': name,
+            'metric': metric.get_config() if metric is not None else None,
+            'output_head_grad': output_head_grad,
+            'loss_head_grad': loss_head_grad,
+            '__group_sym__': symbol.Group(list(losses)+list(outputs)).tojson(),
+            '__num_losses__': len(list(losses)),
+            '__type__': 'loss',
+            '__version__': __version__})
+
         if not loss_head_grad:
             self._loss_symbol = symbol.Group([symbol.make_loss(x, name=x.name+'_loss')
-                                              for x in loss])
+                                              for x in losses])
         else:
-            self._loss_symbol = loss
+            self._loss_symbol = losses
 
         if not output_head_grad:
             self._output_symbol = symbol.Group([symbol.stop_gradient(x, name=x.name+'_out')
-                                                for x in output])
+                                                for x in outputs])
         else:
-            self._output_symbol = output
+            self._output_symbol = outputs
 
         self._label_names = list(label_names) if label_names else []
         self._name = name
@@ -136,6 +162,17 @@
         """Metric for evaluation"""
         return self._metric
 
+    def get_config(self):
+        """get configs for serialization"""
+        return self._kwargs.copy()
+
+
+# pylint: disable=invalid-name
+register = registry.get_register_func(BaseLoss, 'loss')
+create = registry.get_create_func(BaseLoss, 'loss')
+register(BaseLoss)
+# pylint: enable=invalid-name
+
 
 def custom_loss(loss, output, label_names, extra_outputs=(),
                 weight=None, sample_weight=None, name='loss',
@@ -182,7 +219,7 @@
 
     Returns
     -------
-    loss : Loss
+    loss : BaseLoss
         created loss
     """
     label_names = list(label_names)
@@ -193,7 +230,7 @@
                         if i not in loss.list_arguments()]
     loss = _apply_weight(loss, weight=weight, sample_weight=sample_weight)
     loss._set_attr(name=name)
-    return Loss(loss, output, label_names, name, **kwargs)
+    return BaseLoss(loss, output, label_names, name, **kwargs)
 
 
 def multi_loss(losses, extra_outputs=(), name='multi'):
@@ -202,7 +239,7 @@
 
     Parameters
     ----------
-    losses : list of Loss
+    losses : list of BaseLoss
         a list of individual losses with no extra outputs.
     extra_outputs : list of Symbol
         extra outputs for prediction but not used for evaluating
@@ -213,7 +250,7 @@
 
     Returns
     -------
-    loss : Loss
+    loss : BaseLoss
         created loss
     """
     loss = symbol.Group(sum([list(i.loss_symbol) for i in losses], []))
@@ -223,8 +260,8 @@
         for name in i.label_names:
             if name not in label_names:
                 label_names.append(name)
-    ret = Loss(loss, output, label_names, name,
-               output_head_grad=True, loss_head_grad=True)
+    ret = BaseLoss(loss, output, label_names, name,
+                   output_head_grad=True, loss_head_grad=True)
     del ret.metric.metrics[:]
     for i in losses:
         ret.metric.add(i.metric)
@@ -272,7 +309,7 @@
 
     Returns
     -------
-    loss : Loss
+    loss : BaseLoss
         created loss
     """
     metric = _parse_metric(metric, output, label)
@@ -286,7 +323,7 @@
 
     label_names = [x for x in loss.list_arguments()
                    if x not in output.list_arguments()]
-    return Loss(loss, outputs, label_names, name, metric=metric, **kwargs)
+    return BaseLoss(loss, outputs, label_names, name, metric=metric, **kwargs)
 
 
 def l1_loss(output, label, extra_outputs=(), name='l1',
@@ -329,7 +366,7 @@
 
     Returns
     -------
-    loss : Loss
+    loss : BaseLoss
         created loss
     """
     metric = _parse_metric(metric, output, label)
@@ -343,12 +380,12 @@
 
     label_names = [x for x in loss.list_arguments()
                    if x not in output.list_arguments()]
-    return Loss(loss, outputs, label_names, name, metric=metric, **kwargs)
+    return BaseLoss(loss, outputs, label_names, name, metric=metric, **kwargs)
 
 
-def cross_entropy_loss(output, label, sparse_label=True, axis=1,
-                       extra_outputs=(), name='ce', weight=None,
-                       sample_weight=None, metric='acc', **kwargs):
+def softmax_cross_entropy_loss(output, label, sparse_label=True, axis=1,
+                               extra_outputs=(), name='ce', weight=None,
+                               sample_weight=None, metric='acc', **kwargs):
     """Compute the softmax cross entropy loss.
 
     If sparse_label is True, label should contain integer category indicators:
@@ -395,7 +432,7 @@
 
     Returns
     -------
-    loss : Loss
+    loss : BaseLoss
         created loss
     """
     metric = _parse_metric(metric, output, label)
@@ -418,4 +455,4 @@
 
     label_names = [x for x in loss.list_arguments()
                    if x not in output.list_arguments()]
-    return Loss(loss, outputs, label_names, name, metric=metric, **kwargs)
+    return BaseLoss(loss, outputs, label_names, name, metric=metric, **kwargs)
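
The serialization contract for losses is a plain-dict round trip; a sketch mirroring `test_saveload` in the tests below (the `FullyConnected` layer is just a stand-in network):

```python
import mxnet as mx

data = mx.sym.Variable('data')
label = mx.sym.Variable('label')
out = mx.sym.FullyConnected(data, num_hidden=10)
loss = mx.loss.softmax_cross_entropy_loss(out, label)

config = loss.get_config()        # plain dict; symbols are stored as JSON
                                  # under the '__group_sym__' key
loss2 = mx.loss.create(**config)  # dispatches on config['loss'] via the registry
assert loss2.get_config() == config
```
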
diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index 3bd26b7..6b7f87a 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -7,9 +7,12 @@
 from collections import OrderedDict
 
+import pickle
 import numpy
 
-from .base import numeric_types
+from . import base
+from .base import numeric_types, string_types
 from . import ndarray
+from . import registry
 
 
 def check_label_shapes(labels, preds, shape=0):
@@ -24,25 +27,36 @@
         raise ValueError("Shape of labels {} does not match shape of "
                          "predictions {}".format(label_shape, pred_shape))
 
+
 class EvalMetric(object):
     """Base class of all evaluation metrics."""
 
-    def __init__(self, name, num=None, output_names=None, label_names=None):
+    def __init__(self, name, num=None, output_names=None,
+                 label_names=None, **kwargs):
         self.name = name
         self.num = num
         self.output_names = output_names
         self.label_names = label_names
+        self._kwargs = kwargs
         self.reset()
 
     def __str__(self):
         return "EvalMetric: {}".format(dict(self.get_name_value()))
 
-    def __getstate__(self):
-        return self.__dict__.copy()
-
-    def __setstate__(self, state):
-        self.__dict__.update(state)
-        self.reset()
+    def get_config(self):
+        """Save configurations of metric. Can be recreated
+        from configs with metric.create(**config)
+        """
+        config = self._kwargs.copy()
+        config.update({
+            'metric': self.__class__.__name__,
+            'name': self.name,
+            'num': self.num,
+            'output_names': self.output_names,
+            'label_names': self.label_names,
+            '__type__': 'metric',
+            '__version__': base.__version__})
+        return config
 
     def update_dict(self, label, pred):
         """Update the internal evaluation with named label and pred
@@ -119,7 +133,44 @@
             value = [value]
         return list(zip(name, value))
 
+# pylint: disable=invalid-name
+register = registry.get_register_func(EvalMetric, 'metric')
+alias = registry.get_alias_func(EvalMetric, 'metric')
+_create = registry.get_create_func(EvalMetric, 'metric')
+# pylint: enable=invalid-name
 
+
+def create(metric, *args, **kwargs):
+    """Create an evaluation metric.
+
+    Parameters
+    ----------
+    metric : str or callable
+        The name of the metric, or a function
+        providing statistics given pred, label NDArray.
+    *args : list
+        additional arguments to metric constructor
+    **kwargs : dict
+        additional arguments to metric constructor
+
+    Returns
+    -------
+    created metric
+    """
+
+    if callable(metric):
+        return CustomMetric(metric, *args, **kwargs)
+    elif isinstance(metric, list):
+        composite_metric = CompositeEvalMetric()
+        for child_metric in metric:
+            composite_metric.add(create(child_metric, *args, **kwargs))
+        return composite_metric
+
+    return _create(metric, *args, **kwargs)
+
+
+@register
+@alias('composite')
 class CompositeEvalMetric(EvalMetric):
     """Manage multiple evaluation metrics."""
 
@@ -176,10 +227,19 @@
             values.extend(value)
         return (names, values)
 
+    def get_config(self):
+        config = super(CompositeEvalMetric, self).get_config()
+        config.update({'metrics': [i.get_config() for i in self.metrics]})
+        return config
+
+
 ########################
 # CLASSIFICATION METRICS
 ########################
 
+
+@register
+@alias('acc')
 class Accuracy(EvalMetric):
     """Calculate accuracy
 
@@ -188,8 +248,8 @@
     axis : int, default=1
         The axis that represents classes
     """
-    def __init__(self, axis=1, **kwargs):
-        super(Accuracy, self).__init__('accuracy', **kwargs)
+    def __init__(self, axis=1, name='accuracy', **kwargs):
+        super(Accuracy, self).__init__(name, axis=axis, **kwargs)
         self.axis = axis
 
     def update(self, labels, preds):
@@ -206,11 +266,14 @@
             self.sum_metric += (pred_label.flat == label.flat).sum()
             self.num_inst += len(pred_label.flat)
 
+
+@register
+@alias('top_k_accuracy', 'top_k_acc')
 class TopKAccuracy(EvalMetric):
     """Calculate top k predictions accuracy"""
 
-    def __init__(self, top_k=1, **kwargs):
-        super(TopKAccuracy, self).__init__('top_k_accuracy', **kwargs)
+    def __init__(self, top_k=1, name='top_k_accuracy', **kwargs):
+        super(TopKAccuracy, self).__init__(name, top_k=top_k, **kwargs)
         self.top_k = top_k
         assert(self.top_k > 1), 'Please use Accuracy if top_k is no more than 1'
         self.name += '_%d' % self.top_k
@@ -234,11 +297,13 @@
                     self.sum_metric += (pred_label[:, num_classes - 1 - j].flat == label.flat).sum()
             self.num_inst += num_samples
 
+
+@register
 class F1(EvalMetric):
     """Calculate the F1 score of a binary classification problem."""
 
-    def __init__(self, **kwargs):
-        super(F1, self).__init__('f1', **kwargs)
+    def __init__(self, name='f1', **kwargs):
+        super(F1, self).__init__(name, **kwargs)
 
     def update(self, labels, preds):
         check_label_shapes(labels, preds)
@@ -281,6 +346,7 @@
             self.num_inst += 1
 
 
+@register
 class Perplexity(EvalMetric):
     """Calculate perplexity
 
@@ -291,8 +357,8 @@
         counting. usually should be -1. Include
         all entries if None.
     """
-    def __init__(self, ignore_label, **kwargs):
-        super(Perplexity, self).__init__('Perplexity', **kwargs)
+    def __init__(self, ignore_label, name='perplexity', **kwargs):
+        super(Perplexity, self).__init__(name, ignore_label=ignore_label, **kwargs)
         self.ignore_label = ignore_label
 
     def update(self, labels, preds):
@@ -326,11 +392,13 @@
 # REGRESSION METRICS
 ####################
 
+
+@register
 class MAE(EvalMetric):
     """Calculate Mean Absolute Error loss"""
 
-    def __init__(self, **kwargs):
-        super(MAE, self).__init__('mae', **kwargs)
+    def __init__(self, name='mae', **kwargs):
+        super(MAE, self).__init__(name, **kwargs)
 
     def update(self, labels, preds):
         check_label_shapes(labels, preds)
@@ -345,10 +413,12 @@
             self.sum_metric += numpy.abs(label - pred).mean()
             self.num_inst += 1 # numpy.prod(label.shape)
 
+
+@register
 class MSE(EvalMetric):
     """Calculate Mean Squared Error loss"""
-    def __init__(self, **kwargs):
-        super(MSE, self).__init__('mse', **kwargs)
+    def __init__(self, name='mse', **kwargs):
+        super(MSE, self).__init__(name, **kwargs)
 
     def update(self, labels, preds):
         check_label_shapes(labels, preds)
@@ -363,10 +433,12 @@
             self.sum_metric += ((label - pred)**2.0).mean()
             self.num_inst += 1 # numpy.prod(label.shape)
 
+
+@register
 class RMSE(EvalMetric):
     """Calculate Root Mean Squred Error loss"""
-    def __init__(self, **kwargs):
-        super(RMSE, self).__init__('rmse', **kwargs)
+    def __init__(self, name='rmse', **kwargs):
+        super(RMSE, self).__init__(name, **kwargs)
 
     def update(self, labels, preds):
         check_label_shapes(labels, preds)
@@ -381,10 +453,13 @@
             self.sum_metric += numpy.sqrt(((label - pred)**2.0).mean())
             self.num_inst += 1
 
+
+@register
+@alias('ce')
 class CrossEntropy(EvalMetric):
     """Calculate Cross Entropy loss"""
-    def __init__(self, eps=1e-8, **kwargs):
-        super(CrossEntropy, self).__init__('cross-entropy', **kwargs)
+    def __init__(self, eps=1e-8, name='cross-entropy', **kwargs):
+        super(CrossEntropy, self).__init__(name, eps=eps, **kwargs)
         self.eps = eps
 
     def update(self, labels, preds):
@@ -402,6 +477,7 @@
             self.num_inst += label.shape[0]
 
 
+@register
 class Loss(EvalMetric):
     """Dummy metric for directly printing loss"""
     def __init__(self, name='loss', **kwargs):
@@ -413,18 +489,21 @@
             self.num_inst += pred.size
 
 
+@register
 class Torch(Loss):
     """Dummy metric for torch criterions"""
     def __init__(self, name='torch', **kwargs):
         super(Torch, self).__init__(name, **kwargs)
 
 
+@register
 class Caffe(Loss):
     """Dummy metric for caffe criterions"""
     def __init__(self, name='caffe', **kwargs):
         super(Caffe, self).__init__(name, **kwargs)
 
 
+@register
 class CustomMetric(EvalMetric):
     """Custom evaluation metric that takes a NDArray function.
 
@@ -440,11 +519,16 @@
         in outputs for forwarding.
     """
     def __init__(self, feval, name=None, allow_extra_outputs=False, **kwargs):
+        if isinstance(feval, (string_types, bytes)):
+            # feval arrives pickled (via pickle.dumps below) when the metric
+            # is recreated from get_config(); unpickle it back to a callable
+            feval = pickle.loads(feval)
         if name is None:
             name = feval.__name__
             if name.find('<') != -1:
                 name = 'custom(%s)' % name
-        super(CustomMetric, self).__init__(name, **kwargs)
+        super(CustomMetric, self).__init__(
+            name, feval=pickle.dumps(feval),
+            allow_extra_outputs=allow_extra_outputs,
+            **kwargs)
         self._feval = feval
         self._allow_extra_outputs = allow_extra_outputs
 
@@ -465,6 +549,7 @@
                 self.sum_metric += reval
                 self.num_inst += 1
 
+
 # pylint: disable=invalid-name
 def np(numpy_feval, name=None, allow_extra_outputs=False):
     """Create a customized metric from numpy function.
@@ -489,40 +574,3 @@
     feval.__name__ = numpy_feval.__name__
     return CustomMetric(feval, name, allow_extra_outputs)
 # pylint: enable=invalid-name
-
-def create(metric, **kwargs):
-    """Create an evaluation metric.
-
-    Parameters
-    ----------
-    metric : str or callable
-        The name of the metric, or a function
-        providing statistics given pred, label NDArray.
-    """
-
-    if callable(metric):
-        return CustomMetric(metric)
-    elif isinstance(metric, EvalMetric):
-        return metric
-    elif isinstance(metric, list):
-        composite_metric = CompositeEvalMetric()
-        for child_metric in metric:
-            composite_metric.add(create(child_metric, **kwargs))
-        return composite_metric
-
-    metrics = {
-        'acc': Accuracy,
-        'accuracy': Accuracy,
-        'ce': CrossEntropy,
-        'f1': F1,
-        'mae': MAE,
-        'mse': MSE,
-        'rmse': RMSE,
-        'top_k_accuracy': TopKAccuracy
-    }
-
-    try:
-        return metrics[metric.lower()](**kwargs)
-    except:
-        raise ValueError("Metric must be either callable or in {}".format(
-            metrics.keys()))
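
Metrics now resolve through the shared registry, so string names, aliases, and config dicts all go through one path; a quick sketch of the new round trip (the alias strings come from the `@alias` decorators above):

```python
import mxnet as mx

metric = mx.metric.create('acc')          # alias for Accuracy
config = metric.get_config()              # includes 'metric', 'name', 'axis', ...
metric2 = mx.metric.create(**config)      # rebuilds an equivalent metric
assert metric2.get_config() == config

top5 = mx.metric.create('top_k_acc', top_k=5)  # positional/keyword args pass through
```
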
diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index 2db7ce6..4af8c76 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -3,6 +3,7 @@
 """MXNet model module"""
 from __future__ import absolute_import, print_function
 
+import os
 import time
 import logging
 import warnings
@@ -370,7 +371,10 @@
     - symbol will be loaded from ``prefix-symbol.json``.
     - parameters will be loaded from ``prefix-epoch.params``.
     """
-    symbol = sym.load('%s-symbol.json' % prefix)
+    if os.path.exists('%s-symbol.json' % prefix):
+        symbol = sym.load('%s-symbol.json' % prefix)
+    else:
+        symbol = None
     save_dict = nd.load('%s-%04d.params' % (prefix, epoch))
     arg_params = {}
     aux_params = {}
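
With the guard above, `load_checkpoint` degrades gracefully when only parameters were saved; a sketch (the prefix `'ckpt'` is a placeholder):

```python
import mxnet as mx

# returns (None, arg_params, aux_params) when 'ckpt-symbol.json' is absent,
# e.g. for checkpoints whose architecture lives in 'ckpt-module.json' instead
sym, arg_params, aux_params = mx.model.load_checkpoint('ckpt', 10)
```
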
diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py
index aaf1fe0..238b053 100644
--- a/python/mxnet/module/base_module.py
+++ b/python/mxnet/module/base_module.py
@@ -160,20 +160,26 @@
     - `predict`: run prediction on a data set and collect outputs
     - `score`: run prediction on a data set and evaluate performance
 
-    Examples
-    --------
-    An example of creating a mxnet module::
-        >>> import mxnet as mx
+    To create a module for training a classifier::
+        data = mx.sym.Variable('data')
+        output = mx.sym.FullyConnected(data, num_hidden=10)
+        label = mx.sym.Variable('label')
+        loss = mx.loss.softmax_cross_entropy_loss(output, label)
+        model = mx.mod.Module(loss, data_names=('data',))
+        model.fit(..., eval_metric=loss.metric)
+        model.score(..., eval_metric=loss.metric)
 
-        >>> data = mx.symbol.Variable('data')
-        >>> fc1  = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
-        >>> act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
-        >>> fc2  = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64)
-        >>> act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
-        >>> fc3  = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
-        >>> out  = mx.symbol.SoftmaxOutput(fc3, name = 'softmax')
+    To create a module for prediction only::
+        data = mx.sym.Variable('data')
+        output = mx.sym.FullyConnected(data, num_hidden=10)
+        model = mx.mod.Module(output, data_names=('data',))
+        model.bind(data_shapes=[('data', (128, 100))], label_shapes=None)
+        model.load_params('save-0001.params')
+        model.predict(...)
 
-        >>> mod = mx.mod.Module(out)
+    You can also save the module to a checkpoint and load it back::
+        model.save_checkpoint('save', 1)
+        model2 = mx.mod.Module.load('save', 1, context=mx.cpu(0))
     """
     def __init__(self, logger=logging):
         self.logger = logger
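
The docs change above mentions computing the loss outside the module; a rough sketch of that manual pattern, assuming a bound module `mod` wrapping a plain `Symbol`, a data batch `batch`, and a target NDArray `target` (the gradient is supplied by the caller):

```python
mod.forward(batch, is_train=True)
out = mod.get_outputs()[0]
# compute d(loss)/d(out) yourself, e.g. for a squared error against `target`
out_grad = 2 * (out - target)
mod.backward(out_grads=[out_grad])
mod.update()
```
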
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index de54c16..26e0c8b 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -4,6 +4,8 @@
 more `Executor` for data parallelization.
 """
 
+import os
+import json
 import logging
 import warnings
 
@@ -12,7 +14,7 @@
 from .. import symbol as _sym
 from .. import optimizer as opt
 from .. import loss
-from ..base import _Sentinel
+from ..base import _Sentinel, __version__, string_types
 
 from .executor_group import DataParallelExecutorGroup
 from ..model import _create_kvstore, _initialize_kvstore, _update_params, _update_params_on_kvstore
@@ -23,8 +25,8 @@
 
 
 class Module(BaseModule):
-    """Module is a basic module that wrap a `Symbol`. It is functionally the same
-    as the `FeedForward` model, except under the module API.
+    """Module is a basic module that wrap a `Loss` or `Symbol`.
+    It is functionally the same as the `FeedForward` model.
 
     Parameters
     ----------
@@ -48,9 +50,14 @@
     """
     def __init__(self, symbol, data_names=('data',), label_names=_Sentinel,
                  logger=logging, context=ctx.cpu(), work_load_list=None,
-                 fixed_param_names=None, state_names=None):
+                 fixed_param_names=None, state_names=None, **kwargs):
         super(Module, self).__init__(logger=logger)
 
+        if isinstance(symbol, string_types):
+            symbol = _sym.load_json(symbol)
+        elif isinstance(symbol, dict):
+            symbol = loss.create(**symbol)
+
         if isinstance(context, ctx.Context):
             context = [context]
         self._context = context
@@ -59,7 +66,19 @@
         assert len(work_load_list) == len(self._context)
         self._work_load_list = work_load_list
 
-        if isinstance(symbol, loss.Loss):
+        self._kwargs = kwargs
+        if isinstance(symbol, loss.BaseLoss):
+            self._kwargs['symbol'] = symbol.get_config()
+        else:
+            self._kwargs['symbol'] = symbol.tojson()
+        self._kwargs.update({
+            'data_names': data_names,
+            'fixed_param_names': fixed_param_names,
+            'state_names': state_names,
+            '__type__': 'module',
+            '__version__': __version__})
+
+        if isinstance(symbol, loss.BaseLoss):
             self._loss = symbol
             self._symbol = _sym.Group([self._loss.output_symbol, self._loss.loss_symbol])
             num_output = len(self._loss.output_symbol.list_outputs())
@@ -67,7 +86,7 @@
             self._output_range = (0, num_output)
             self._loss_range = (num_output, num_output+num_loss)
             assert label_names is _Sentinel, \
-                "label_names has been deprecated. do not use"
+                "label_names has been deprecated. Do not set."
             label_names = self._loss.label_names
         else:
             self._symbol = symbol
@@ -120,34 +139,38 @@
     def load(prefix, epoch, load_optimizer_states=False, **kwargs):
         """Create a model from previously saved checkpoint.
 
+        For example, use::
+            mod = mx.mod.Module.load('test', 100, context=mx.gpu(0))
+
+        to load from "test-module.json" and "test-0100.params"
+
         Parameters
         ----------
         prefix : str
             path prefix of saved model files. You should have
-            "prefix-symbol.json", "prefix-xxxx.params", and
-            optionally "prefix-xxxx.states", where xxxx is the
-            epoch number.
+            "prefix-symbol.json"/"prefix-module.json",
+            "prefix-xxxx.params", and optionally "prefix-xxxx.states",
+            where xxxx is the epoch number.
         epoch : int
             epoch to load.
         load_optimizer_states : bool
             whether to load optimizer states. Checkpoint needs
             to have been made with save_optimizer_states=True.
-        data_names : list of str
-            Default is `('data')` for a typical model used in image classification.
-        label_names : list of str
-            Default is `('softmax_label')` for a typical model used in image
-            classification.
-        logger : Logger
-            Default is `logging`.
         context : Context or list of Context
             Default is `cpu()`.
         work_load_list : list of number
             Default `None`, indicating uniform workload.
-        fixed_param_names: list of str
-            Default `None`, indicating no network parameters are fixed.
+        logger : Logger
+            Default is `logging`.
         """
         sym, args, auxs = load_checkpoint(prefix, epoch)
-        mod = Module(symbol=sym, **kwargs)
+        if os.path.exists('%s-module.json'%prefix):
+            with open('%s-module.json'%prefix) as fin:
+                config = json.load(fin)
+            config.update(kwargs)
+            mod = Module(**config)
+        else:
+            mod = Module(sym, **kwargs)
+
         mod._arg_params = args
         mod._aux_params = auxs
         mod.params_initialized = True
@@ -157,7 +180,11 @@
 
     def save_checkpoint(self, prefix, epoch, save_optimizer_states=False):
         """Save current progress to checkpoint.
-        Use mx.callback.module_checkpoint as epoch_end_callback to save during training.
+        Use mx.callback.module_checkpoint as
+        epoch_end_callback to save during training.
+
+        Outputs 'prefix-module.json', 'prefix-symbol.json',
+        'prefix-(epoch).params', and optionally 'prefix-(epoch).states'.
 
         Parameters
         ----------
@@ -168,7 +195,13 @@
         save_optimizer_states : bool
             Whether to save optimizer states for continue training
         """
-        self._symbol.save('%s-symbol.json'%prefix)
+        if self._loss is not None:
+            self._loss.output_symbol.save('%s-symbol.json'%prefix)
+        else:
+            self._symbol.save('%s-symbol.json'%prefix)
+        with open('%s-module.json'%prefix, 'w') as fout:
+            json.dump(self._kwargs, fout)
+
         param_name = '%s-%04d.params' % (prefix, epoch)
         self.save_params(param_name)
         logging.info('Saved checkpoint to \"%s\"', param_name)
diff --git a/python/mxnet/registry.py b/python/mxnet/registry.py
new file mode 100644
index 0000000..4d1cf0f
--- /dev/null
+++ b/python/mxnet/registry.py
@@ -0,0 +1,127 @@
+# coding: utf-8
+# pylint: disable=no-member
+
+"""Registry for serializable objects."""
+from __future__ import absolute_import
+
+import warnings
+
+from .base import string_types
+
+_REGISTRY = {}
+
+
+def get_register_func(base_class, nickname):
+    """Get registrator function.
+
+    Parameters
+    ----------
+    base_class : type
+        base class for classes that will be registered
+    nickname : str
+        nickname of base_class for logging
+
+    Returns
+    -------
+    a registrator function
+    """
+    if base_class not in _REGISTRY:
+        _REGISTRY[base_class] = {}
+    registry = _REGISTRY[base_class]
+
+    def register(klass, name=None):
+        """Register functions"""
+        assert issubclass(klass, base_class), \
+            "Can only register subclass of %s"%base_class.__name__
+        if name is None:
+            name = klass.__name__.lower()
+        if name in registry:
+            warnings.warn(
+                "\033[91mNew %s %s.%s registered with name %s is"
+                "overriding existing %s %s.%s\033[0m"%(
+                    nickname, klass.__module__, klass.__name__, name,
+                    nickname, registry[name].__module__, registry[name].__name__),
+                UserWarning, stacklevel=2)
+        registry[name] = klass
+        return klass
+
+    register.__doc__ = "Register %s to the %s factory"%(nickname, nickname)
+    return register
+
+
+def get_alias_func(base_class, nickname):
+    """Get registrator function that allow aliases.
+
+    Parameters
+    ----------
+    base_class : type
+        base class for classes that will be reigstered
+    nickname : str
+        nickname of base_class for logging
+
+    Returns
+    -------
+    a registrator function
+    """
+    register = get_register_func(base_class, nickname)
+
+    def alias(*aliases):
+        """alias registrator"""
+        def reg(klass):
+            """registrator function"""
+            for name in aliases:
+                register(klass, name)
+            return klass
+        return reg
+    return alias
+
+
+def get_create_func(base_class, nickname):
+    """Get creator function
+
+    Parameters
+    ----------
+    base_class : type
+        base class for classes that will be reigstered
+    nickname : str
+        nickname of base_class for logging
+
+    Returns
+    -------
+    a creator function
+    """
+    if base_class not in _REGISTRY:
+        _REGISTRY[base_class] = {}
+    registry = _REGISTRY[base_class]
+
+    def create(*args, **kwargs):
+        """Create instance from config"""
+        if len(args):
+            name = args[0]
+            args = args[1:]
+        else:
+            name = kwargs.pop(nickname)
+
+        if isinstance(name, base_class):
+            assert len(args) == 0 and len(kwargs) == 0, \
+                "%s is already an instance. Additional arguments are invalid"%(nickname)
+            return name
+
+        assert isinstance(name, string_types), "%s must be of string type"%nickname
+        name = name.lower()
+        assert name in registry, \
+            "%s is not registered. Please register with %s.register first"%(
+                str(name), nickname)
+        return registry[name](*args, **kwargs)
+
+    create.__doc__ = """Create a %s instance from config.
+        
+Parameters
+----------
+%s : str or %s instance
+    class name of desired instance. If is a instance,
+    it will be returned directly.
+**kwargs : dict
+    arguments to be passed to constructor"""%(nickname, nickname, base_class.__name__)
+
+    return create
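
`registry.py` is deliberately generic; a toy sketch of wiring it up for one's own class hierarchy, the same way `loss.py` and `metric.py` do above (`Base` and `Foo` are made-up names):

```python
from mxnet import registry

class Base(object):
    """Root of a hypothetical registerable hierarchy."""
    def __init__(self, **kwargs):
        self._kwargs = kwargs

# one factory pair per base class, keyed by a nickname used in messages
register = registry.get_register_func(Base, 'thing')
create = registry.get_create_func(Base, 'thing')

@register
class Foo(Base):
    pass

obj = create('foo')   # looked up by lowercased class name
same = create(obj)    # passing an instance returns it unchanged
```
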
diff --git a/tests/python/unittest/test_loss.py b/tests/python/unittest/test_loss.py
index 714a656..4ec17b7 100644
--- a/tests/python/unittest/test_loss.py
+++ b/tests/python/unittest/test_loss.py
@@ -21,7 +21,7 @@
     data_iter = mx.io.NDArrayIter(data, label, batch_size=10, label_name='label')
     output = get_net(nclass)
     l = mx.symbol.Variable('label')
-    loss = mx.loss.cross_entropy_loss(output, l)
+    loss = mx.loss.softmax_cross_entropy_loss(output, l)
     mod = mx.mod.Module(loss)
     mod.fit(data_iter, eval_metric=loss.metric, num_epoch=200, optimizer_params={'learning_rate': 1.})
     assert mod.score(data_iter, loss.metric)[0][1] == 1.0
@@ -87,7 +87,7 @@
     output = get_net(nclass)
     l = mx.symbol.Variable('label')
     w = mx.symbol.Variable('w')
-    loss = mx.loss.cross_entropy_loss(output, l, sample_weight=w)
+    loss = mx.loss.softmax_cross_entropy_loss(output, l, sample_weight=w)
     mod = mx.mod.Module(loss)
     mod.fit(data_iter, eval_metric=loss.metric, num_epoch=200,
             optimizer_params={'learning_rate': 1.})
@@ -111,7 +111,7 @@
     output2 = mx.symbol.FullyConnected(act3, name='output2', num_hidden=1)
     l1 = mx.symbol.Variable('label1')
     l2 = mx.symbol.Variable('label2')
-    loss1 = mx.loss.cross_entropy_loss(output1, l1)
+    loss1 = mx.loss.softmax_cross_entropy_loss(output1, l1)
     loss2 = mx.loss.l2_loss(output2, l2)
     loss = mx.loss.multi_loss([loss1, loss2])
     mod = mx.mod.Module(loss)
@@ -124,7 +124,36 @@
     assert score[2][1] < 0.2
 
 
+def test_saveload():
+    nclass = 10
+    output = get_net(nclass)
+    l = mx.symbol.Variable('label')
+    loss = mx.loss.softmax_cross_entropy_loss(output, l)
+    assert mx.loss.create(**loss.get_config()).get_config() == loss.get_config()
+
+
+def test_saveload2():
+    mx.random.seed(1234)
+    np.random.seed(1234)
+    nclass = 10
+    N = 20
+    data = mx.random.uniform(-1, 1, shape=(N, nclass))
+    label = mx.nd.array(np.random.randint(0, nclass, size=(N,)), dtype='int32')
+    data_iter = mx.io.NDArrayIter(data, label, batch_size=10, label_name='label')
+    output = get_net(nclass)
+    l = mx.symbol.Variable('label')
+    loss = mx.loss.softmax_cross_entropy_loss(output, l)
+    mod = mx.mod.Module(loss)
+    mod.fit(data_iter, eval_metric=loss.metric, num_epoch=100, optimizer_params={'learning_rate': 1.})
+    mod.save_checkpoint('test', 100, save_optimizer_states=True)
+    mod = mx.mod.Module.load('test', 100, load_optimizer_states=True)
+    mod.fit(data_iter, eval_metric=mod._loss.metric, num_epoch=100, optimizer_params={'learning_rate': 1.})
+    assert mod.score(data_iter, loss.metric)[0][1] == 1.0
+
+
 if __name__ == '__main__':
+    test_saveload()
+    test_saveload2()
     test_multi_loss()
     test_sample_weight_loss()
     test_custom_loss()