# pylint: disable=too-many-instance-attributes, too-many-arguments, protected-access, too-many-branches
# pylint: disable=too-many-public-methods, too-many-statements
"""A `Module` implement the `BaseModule` API by wrapping a `Symbol` and one or
more `Executor` for data parallelization.
"""
import os
import json
import logging
import warnings
from .. import context as ctx
from .. import ndarray as nd
from .. import symbol as _sym
from .. import optimizer as opt
from .. import loss
from ..base import _Sentinel, __version__, string_types
from .executor_group import DataParallelExecutorGroup
from ..model import _create_kvstore, _initialize_kvstore, _update_params, _update_params_on_kvstore
from ..model import load_checkpoint
from ..initializer import Uniform, InitDesc
from .base_module import BaseModule, _check_input_names, _parse_data_desc
class Module(BaseModule):
"""Module is a basic module that wrap a `Loss` or `Symbol`.
It is functionally the same as the `FeedForward` model.
Parameters
----------
symbol : Symbol
data_names : list of str
Default is `('data',)` for a typical model used in image classification.
label_names : list of str
Default is `('softmax_label',)` for a typical model used in image
classification.
logger : Logger
Default is `logging`.
context : Context or list of Context
Default is `cpu()`.
work_load_list : list of number
Default `None`, indicating uniform workload.
fixed_param_names: list of str
Default `None`, indicating no network parameters are fixed.
state_names : list of str
States are similar to data and label, but are not provided by the data iterator.
Instead, they are initialized to 0 and can be set by `set_states()`.
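Examples
--------
A minimal sketch of constructing a module for prediction; `sym` is assumed to be
an existing `Symbol` with a single input named 'data'::
>>> mod = mx.mod.Module(symbol=sym, data_names=('data',), context=mx.cpu())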
"""
def __init__(self, symbol, data_names=('data',), label_names=_Sentinel,
logger=logging, context=ctx.cpu(), work_load_list=None,
fixed_param_names=None, state_names=None, **kwargs):
super(Module, self).__init__(logger=logger)
if isinstance(symbol, string_types):
symbol = _sym.load_json(symbol)
elif isinstance(symbol, dict):
symbol = loss.create(**symbol)
if isinstance(context, ctx.Context):
context = [context]
self._context = context
if work_load_list is None:
work_load_list = [1] * len(self._context)
assert len(work_load_list) == len(self._context)
self._work_load_list = work_load_list
self._kwargs = kwargs
if isinstance(symbol, loss.BaseLoss):
self._kwargs['symbol'] = symbol.get_config()
else:
self._kwargs['symbol'] = symbol.tojson()
self._kwargs.update({
'data_names': data_names,
'fixed_param_names': fixed_param_names,
'state_names': state_names,
'__type__': 'module',
'__version__': __version__})
if isinstance(symbol, loss.BaseLoss):
self._loss = symbol
self._symbol = _sym.Group([self._loss.output_symbol, self._loss.loss_symbol])
num_output = len(self._loss.output_symbol.list_outputs())
num_loss = len(self._loss.loss_symbol.list_outputs())
self._output_range = (0, num_output)
self._loss_range = (num_output, num_output+num_loss)
assert label_names is _Sentinel, \
"label_names has been deprecated. Do not set."
label_names = self._loss.label_names
else:
self._symbol = symbol
self._loss = None
self._output_range = (0, len(self._symbol.list_outputs()))
if label_names is _Sentinel:
label_names = ('softmax_label',)
else:
warnings.warn(
"label_names has been deprecated. For prediction only, "
"do not set label_names. For training, please use the new "
"mxnet.loss.* classes for symbol", stacklevel=2)
data_names = list(data_names) if data_names is not None else []
label_names = list(label_names) if label_names is not None else []
state_names = list(state_names) if state_names is not None else []
fixed_param_names = list(fixed_param_names) if fixed_param_names is not None else []
_check_input_names(self._symbol, data_names, "data", True)
_check_input_names(self._symbol, label_names, "label", False)
_check_input_names(self._symbol, state_names, "state", True)
_check_input_names(self._symbol, fixed_param_names, "fixed_param", True)
arg_names = self._symbol.list_arguments()
input_names = data_names + label_names + state_names
self._param_names = [x for x in arg_names if x not in input_names]
self._fixed_param_names = fixed_param_names
self._aux_names = self._symbol.list_auxiliary_states()
self._data_names = data_names
self._label_names = label_names
self._state_names = state_names
self._output_names = self._symbol.list_outputs()
self._arg_params = None
self._aux_params = None
self._params_dirty = False
self._optimizer = None
self._kvstore = None
self._update_on_kvstore = None
self._updater = None
self._preload_opt_states = None
self._grad_req = None
self._exec_group = None
self._data_shapes = None
self._label_shapes = None
@staticmethod
def load(prefix, epoch, load_optimizer_states=False, **kwargs):
"""Create a model from previously saved checkpoint.
For example, use::
mod = mx.mod.Module.load('test', 100, context=mx.gpu(0))
to load from "test-module.json" and "test-0100.params"
Parameters
----------
prefix : str
path prefix of saved model files. You should have
"prefix-symbol.json"/"prefix-module.json",
"prefix-xxxx.params", and optionally "prefix-xxxx.states",
where xxxx is the epoch number.
epoch : int
epoch to load.
load_optimizer_states : bool
whether to load optimizer states. Checkpoint needs
to have been made with save_optimizer_states=True.
context : Context or list of Context
Default is `cpu()`.
work_load_list : list of number
Default `None`, indicating uniform workload.
logger : Logger
Default is `logging`.
"""
sym, args, auxs = load_checkpoint(prefix, epoch)
if os.path.exists('%s-module.json'%prefix):
with open('%s-module.json'%prefix) as fin:
    config = json.load(fin)
config.update(kwargs)
mod = Module(**config)
else:
mod = Module(sym, **kwargs)
mod._arg_params = args
mod._aux_params = auxs
mod.params_initialized = True
if load_optimizer_states:
mod._preload_opt_states = '%s-%04d.states'%(prefix, epoch)
return mod
def save_checkpoint(self, prefix, epoch, save_optimizer_states=False):
"""Save current progress to checkpoint.
Use mx.callback.module_checkpoint as
epoch_end_callback to save during training.
Outputs 'prefix-module.json', 'prefix-symbol.json',
'prefix-(epoch).params', and optionally 'prefix-(epoch).states'.
Parameters
----------
prefix : str
The file prefix to checkpoint to
epoch : int
The current epoch number
save_optimizer_states : bool
Whether to save optimizer states for continued training.
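Examples
--------
An illustrative call (the prefix and epoch are hypothetical)::
>>> mod.save_checkpoint('mymodel', 5)
This writes 'mymodel-symbol.json', 'mymodel-module.json' and 'mymodel-0005.params'.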
"""
if self._loss is not None:
self._loss.output_symbol.save('%s-symbol.json'%prefix)
else:
self._symbol.save('%s-symbol.json'%prefix)
with open('%s-module.json'%prefix, 'w') as fout:
json.dump(self._kwargs, fout)
param_name = '%s-%04d.params' % (prefix, epoch)
self.save_params(param_name)
logging.info('Saved checkpoint to \"%s\"', param_name)
if save_optimizer_states:
state_name = '%s-%04d.states' % (prefix, epoch)
self.save_optimizer_states(state_name)
logging.info('Saved optimizer state to \"%s\"', state_name)
def _reset_bind(self):
"""Internal function to reset binded state."""
if self.binded and self.params_initialized:
self._sync_params_from_devices()
self.binded = False
self._exec_group = None
self._data_shapes = None
self._label_shapes = None
@property
def data_names(self):
"""A list of names for data required by this module."""
return self._data_names
@property
def label_names(self):
"""A list of names for labels required by this module."""
return self._label_names
@property
def output_names(self):
"""A list of names for the outputs of this module."""
return self._output_names[self._output_range[0]:self._output_range[1]]
@property
def data_shapes(self):
"""Get data shapes.
Returns
-------
A list of `(name, shape)` pairs.
"""
assert self.binded
return self._data_shapes
@property
def label_shapes(self):
"""Get label shapes.
Returns
-------
A list of `(name, shape)` pairs. The return value could be `None` if
the module does not need labels, or if the module is not bound for
training (in this case, label information is not available).
"""
assert self.binded
return self._label_shapes
@property
def output_shapes(self):
"""Get output shapes.
Returns
-------
A list of `(name, shape)` pairs.
"""
assert self.binded
return self._exec_group.get_output_shapes()
def get_params(self):
"""Get current parameters.
Returns
-------
`(arg_params, aux_params)`, each a dictionary mapping parameter names to
`NDArray` values.
"""
assert self.binded and self.params_initialized
if self._params_dirty:
self._sync_params_from_devices()
return (self._arg_params, self._aux_params)
def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=None,
allow_missing=False, force_init=False):
"""Initialize the parameters and auxiliary states.
Parameters
----------
initializer : Initializer
Called to initialize parameters if needed.
arg_params : dict
If not None, should be a dictionary of existing arg_params. Initialization
will be copied from that.
aux_params : dict
If not None, should be a dictionary of existing aux_params. Initialization
will be copied from that.
allow_missing : bool
If true, params could contain missing values, and the initializer will be
called to fill those missing params.
force_init : bool
If true, will force re-initialize even if already initialized.
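Examples
--------
A typical call after `bind`; the initializer shown is just one possible choice::
>>> mod.init_params(initializer=mx.init.Uniform(scale=0.01))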
"""
if self.params_initialized and not force_init:
warnings.warn("Parameters already initialized and force_init=False. "
"init_params call ignored.", stacklevel=2)
return
assert self.binded, 'call bind before initializing the parameters'
def _impl(name, arr, cache):
"""Internal helper for parameter initialization"""
if cache is not None:
if name in cache:
cache_arr = cache[name]
# just in case the cached array is just the target itself
if cache_arr is not arr:
cache_arr.copyto(arr)
else:
if not allow_missing:
raise RuntimeError("%s is not presented" % name)
if initializer is not None:
initializer(name, arr)
else:
initializer(name, arr)
attrs = self._symbol.attr_dict()
for name, arr in self._arg_params.items():
desc = InitDesc(name, attrs.get(name, None))
_impl(desc, arr, arg_params)
for name, arr in self._aux_params.items():
desc = InitDesc(name, attrs.get(name, None))
_impl(desc, arr, aux_params)
self.params_initialized = True
self._params_dirty = False
# copy the initialized parameters to devices
self._exec_group.set_params(self._arg_params, self._aux_params)
def set_params(self, arg_params, aux_params, allow_missing=False, force_init=True):
"""Assign parameter and aux state values.
Parameters
----------
arg_params : dict
Dictionary of name to value (`NDArray`) mapping.
aux_params : dict
Dictionary of name to value (`NDArray`) mapping.
allow_missing : bool
If true, params could contain missing values, and the initializer will be
called to fill those missing params.
force_init : bool
If true, will force re-initialize even if already initialized.
Examples
--------
An example of setting module parameters::
>>> sym, arg_params, aux_params = \
... mx.model.load_checkpoint(model_prefix, n_epoch_load)
>>> mod.set_params(arg_params=arg_params, aux_params=aux_params)
"""
if not allow_missing:
self.init_params(initializer=None, arg_params=arg_params, aux_params=aux_params,
allow_missing=allow_missing, force_init=force_init)
return
if self.params_initialized and not force_init:
warnings.warn("Parameters already initialized and force_init=False. "
"set_params call ignored.", stacklevel=2)
return
self._exec_group.set_params(arg_params, aux_params)
# because we didn't update self._arg_params, they are dirty now.
self._params_dirty = True
self.params_initialized = True
def bind(self, data_shapes, label_shapes=None, for_training=True,
inputs_need_grad=False, force_rebind=False, shared_module=None,
grad_req='write'):
"""Bind the symbols to construct executors. This is necessary before one
can perform computation with the module.
Parameters
----------
data_shapes : list of (str, tuple)
Typically is `data_iter.provide_data`.
label_shapes : list of (str, tuple)
Typically is `data_iter.provide_label`.
for_training : bool
Default is `True`. Whether the executors should be bound for training.
inputs_need_grad : bool
Default is `False`. Whether the gradients with respect to the input data need
to be computed. Typically this is not needed, but it might be when implementing
composition of modules.
force_rebind : bool
Default is `False`. This function does nothing if the executors are already
bound. If `True`, the executors will be forced to rebind.
shared_module : Module
Default is `None`. This is used in bucketing. When not `None`, the shared module
essentially corresponds to a different bucket -- a module with a different symbol
but the same set of parameters (e.g. unrolled RNNs with different lengths).
grad_req : str, list of str, dict of str to str
Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
(default is 'write').
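Examples
--------
A minimal sketch, assuming a single input named 'data'; the shape shown is
hypothetical::
>>> mod.bind(data_shapes=[('data', (32, 3, 224, 224))], for_training=False)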
"""
# force rebinding is typically used when one wants to switch from
# training to the prediction phase.
if force_rebind:
self._reset_bind()
if self.binded:
self.logger.warning('Already bound, ignoring bind()')
return
self.for_training = for_training
self.inputs_need_grad = inputs_need_grad
self.binded = True
self._grad_req = grad_req
if not for_training:
assert not inputs_need_grad
elif self._loss is None:
warnings.warn(
"Training with ***Output symbols is deprecated. "
"Please use mxnet.loss.* classes instead. See "
"mxnet.mod.Module's document for usage")
self._data_shapes, self._label_shapes = _parse_data_desc(
self.data_names, self.label_names, data_shapes, label_shapes)
if shared_module is not None:
assert isinstance(shared_module, Module) and \
shared_module.binded and shared_module.params_initialized
shared_group = shared_module._exec_group
else:
shared_group = None
self._exec_group = DataParallelExecutorGroup(
self._symbol, self._context, self._work_load_list, self._data_shapes,
self._label_shapes, self._param_names, for_training, inputs_need_grad,
shared_group, logger=self.logger, fixed_param_names=self._fixed_param_names,
grad_req=grad_req, state_names=self._state_names)
self._total_exec_bytes = self._exec_group._total_exec_bytes
if shared_module is not None:
self.params_initialized = True
self._arg_params = shared_module._arg_params
self._aux_params = shared_module._aux_params
elif self.params_initialized:
# if the parameters are already initialized, we are re-binding
# so automatically copy the already initialized params
self._exec_group.set_params(self._arg_params, self._aux_params)
else:
assert self._arg_params is None and self._aux_params is None
param_arrays = [
nd.zeros(x[0].shape, dtype=x[0].dtype)
for x in self._exec_group.param_arrays
]
self._arg_params = {name:arr for name, arr in zip(self._param_names, param_arrays)}
aux_arrays = [
nd.zeros(x[0].shape, dtype=x[0].dtype)
for x in self._exec_group.aux_arrays
]
self._aux_params = {name:arr for name, arr in zip(self._aux_names, aux_arrays)}
if shared_module is not None and shared_module.optimizer_initialized:
self.borrow_optimizer(shared_module)
def reshape(self, data_shapes, label_shapes=None):
"""Reshape the module for new input shapes.
Parameters
----------
data_shapes : list of (str, tuple)
Typically is `data_iter.provide_data`.
label_shapes : list of (str, tuple)
Typically is `data_iter.provide_label`.
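Examples
--------
A sketch of switching to a larger batch size; it assumes the module was bound
with the same input name and a compatible shape::
>>> mod.reshape(data_shapes=[('data', (64, 20))])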
"""
assert self.binded
self._data_shapes, self._label_shapes = _parse_data_desc(
self.data_names, self.label_names, data_shapes, label_shapes)
self._exec_group.reshape(self._data_shapes, self._label_shapes)
def init_optimizer(self, kvstore='local', optimizer='sgd',
optimizer_params=(('learning_rate', 0.01),), force_init=False):
"""Install and initialize optimizers.
Parameters
----------
kvstore : str or KVStore
Default `'local'`.
optimizer : str or Optimizer
Default `'sgd'`
optimizer_params : dict
Default `(('learning_rate', 0.01),)`. The default value is not a dictionary,
just to avoid a pylint warning about dangerous default values.
force_init : bool
Default `False`, indicating whether to force re-initializing the
optimizer when an optimizer is already installed.
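Examples
--------
An illustrative call; the learning rate shown is arbitrary::
>>> mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.005),))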
"""
assert self.binded and self.params_initialized
if self.optimizer_initialized and not force_init:
self.logger.warning('optimizer already initialized, ignoring...')
return
(kvstore, update_on_kvstore) = \
_create_kvstore(kvstore, len(self._context), self._arg_params)
batch_size = self._exec_group.batch_size
if kvstore and 'dist' in kvstore.type and '_sync' in kvstore.type:
batch_size *= kvstore.num_workers
rescale_grad = 1.0/batch_size
if isinstance(optimizer, str):
idx2name = {}
if update_on_kvstore:
idx2name.update(enumerate(self._exec_group.param_names))
else:
for k in range(len(self._context)):
idx2name.update({i*len(self._context)+k: n
for i, n in enumerate(self._exec_group.param_names)})
optimizer_params = dict(optimizer_params)
if 'rescale_grad' not in optimizer_params:
optimizer_params['rescale_grad'] = rescale_grad
optimizer = opt.create(optimizer,
sym=self._symbol, param_idx2name=idx2name,
**optimizer_params)
else:
assert isinstance(optimizer, opt.Optimizer)
if optimizer.rescale_grad != rescale_grad:
#pylint: disable=no-member
warnings.warn(
"Optimizer created manually outside Module but rescale_grad " +
"is not normalized to 1.0/batch_size/num_workers (%s vs. %s). "%(
optimizer.rescale_grad, rescale_grad) +
"Is this intended?", stacklevel=2)
self._optimizer = optimizer
self._kvstore = kvstore
self._update_on_kvstore = update_on_kvstore
self._updater = None
if kvstore:
# copy initialized local parameters to kvstore
_initialize_kvstore(kvstore=kvstore,
param_arrays=self._exec_group.param_arrays,
arg_params=self._arg_params,
param_names=self._param_names,
update_on_kvstore=update_on_kvstore)
if update_on_kvstore:
kvstore.set_optimizer(self._optimizer)
else:
self._updater = opt.get_updater(optimizer)
self.optimizer_initialized = True
if self._preload_opt_states is not None:
self.load_optimizer_states(self._preload_opt_states)
self._preload_opt_states = None
def borrow_optimizer(self, shared_module):
"""Borrow optimizer from a shared module. Used in bucketing, where exactly the same
optimizer (esp. kvstore) is used.
Parameters
----------
shared_module : Module
"""
assert shared_module.optimizer_initialized
self._optimizer = shared_module._optimizer
self._kvstore = shared_module._kvstore
self._update_on_kvstore = shared_module._update_on_kvstore
self._updater = shared_module._updater
self.optimizer_initialized = True
def forward(self, data_batch, is_train=None):
"""Forward computation.
Parameters
----------
data_batch : DataBatch
Could be anything with a similar API implemented.
is_train : bool
Default is `None`, which means `is_train` takes the value of `self.for_training`.
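Examples
--------
A minimal sketch; `data_batch` is assumed to come from an existing `DataIter`::
>>> mod.forward(data_batch, is_train=False)
>>> print(mod.get_outputs()[0].shape)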
"""
assert self.binded and self.params_initialized
self._exec_group.forward(data_batch, is_train)
def backward(self, out_grads=None):
"""Backward computation.
Parameters
----------
out_grads : NDArray or list of NDArray, optional
Gradient on the outputs to be propagated back.
This parameter is only needed when bind is called
on outputs that are not a loss function.
"""
assert self.binded and self.params_initialized
self._exec_group.backward(out_grads=out_grads)
def update(self):
"""Update parameters according to the installed optimizer and the gradients computed
in the previous forward-backward batch.
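Examples
--------
A sketch of one training step; `data_batch` is assumed to come from a `DataIter`::
>>> mod.forward(data_batch, is_train=True)
>>> mod.backward()
>>> mod.update()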
"""
assert self.binded and self.params_initialized and self.optimizer_initialized
self._params_dirty = True
if self._update_on_kvstore:
_update_params_on_kvstore(self._exec_group.param_arrays,
self._exec_group.grad_arrays,
self._kvstore)
else:
_update_params(self._exec_group.param_arrays,
self._exec_group.grad_arrays,
updater=self._updater,
num_device=len(self._context),
kvstore=self._kvstore)
def get_outputs(self, merge_multi_context=True):
"""Get outputs of the previous forward computation.
If `merge_multi_context` is `True`, it is like `[out1, out2]`. Otherwise, it
is like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`. All the output
elements are `NDArray`. When `merge_multi_context` is `False`, those `NDArray`
might live on different devices.
Parameters
----------
merge_multi_context : bool
Default is `True`. In the case when data-parallelism is used, the outputs
will be collected from multiple devices. A `True` value indicates that we
should merge the collected results so that they look like they came from a
single executor.
Returns
-------
list of NDArray or list of list of NDArray
Output
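Examples
--------
A sketch of fetching the first output as a NumPy array after a forward pass::
>>> out = mod.get_outputs()[0].asnumpy()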
"""
assert self.binded and self.params_initialized
return self._exec_group.get_outputs(merge_multi_context=merge_multi_context,
begin=self._output_range[0],
end=self._output_range[1])
def get_input_grads(self, merge_multi_context=True):
"""Get the gradients with respect to the inputs of the module.
If `merge_multi_context` is `True`, it is like `[grad1, grad2]`. Otherwise, it
is like `[[grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2]]`. All the output
elements are `NDArray`.
Parameters
----------
merge_multi_context : bool
Default is `True`. In the case when data-parallelism is used, the gradients
will be collected from multiple devices. A `True` value indicates that we
should merge the collected results so that they look like they came from a
single executor.
Returns
-------
list of NDArray or list of list of NDArray
Input gradients
"""
assert self.binded and self.params_initialized and self.inputs_need_grad
return self._exec_group.get_input_grads(merge_multi_context=merge_multi_context)
def get_states(self, merge_multi_context=True):
"""Get states from all devices
If `merge_multi_context` is `True`, it is like `[out1, out2]`. Otherwise, it
is like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`. All the output
elements are `NDArray`.
Parameters
----------
merge_multi_context : bool
Default is `True`. In the case when data-parallelism is used, the states
will be collected from multiple devices. A `True` value indicates that we
should merge the collected results so that they look like they came from a
single executor.
Returns
-------
list of NDArray or list of list of NDArray
States
"""
assert self.binded and self.params_initialized
return self._exec_group.get_states(merge_multi_context=merge_multi_context)
def set_states(self, states=None, value=None):
"""Set value for states. Only one of states & value can be specified.
Parameters
----------
states : list of list of NDArrays
Source state arrays, formatted like [[state1_dev1, state1_dev2],
[state2_dev1, state2_dev2]].
value : number
a single scalar value for all state arrays.
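Examples
--------
A minimal sketch of resetting all states to zero (for a module that has states)::
>>> mod.set_states(value=0)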
"""
assert self.binded and self.params_initialized
self._exec_group.set_states(states, value)
def update_metric(self, eval_metric, labels):
"""Evaluate and accumulate evaluation metric on outputs of the last forward computation.
Parameters
----------
eval_metric : EvalMetric
labels : list of NDArray
Typically `data_batch.label`.
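Examples
--------
A sketch of accumulating accuracy after a forward pass; `data_batch` is assumed
to come from a `DataIter`::
>>> metric = mx.metric.Accuracy()
>>> mod.update_metric(metric, data_batch.label)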
"""
self._exec_group.update_metric(eval_metric, labels)
def _sync_params_from_devices(self):
"""Synchronize parameters from devices to CPU. This function should be called after
calling `update` that updates the parameters on the devices, before one can read the
latest parameters from `self._arg_params` and `self._aux_params`.
"""
self._exec_group.get_params(self._arg_params, self._aux_params)
self._params_dirty = False
def save_optimizer_states(self, fname):
"""Save optimizer (updater) state to file
Parameters
----------
fname : str
Path to output states file.
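Examples
--------
An illustrative call (the file name is hypothetical)::
>>> mod.save_optimizer_states('mymodel-0005.states')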
"""
assert self.optimizer_initialized
if self._update_on_kvstore:
self._kvstore.save_optimizer_states(fname)
else:
with open(fname, 'wb') as fout:
fout.write(self._updater.get_states())
def load_optimizer_states(self, fname):
"""Load optimizer (updater) state from file
Parameters
----------
fname : str
Path to input states file.
"""
assert self.optimizer_initialized
if self._update_on_kvstore:
self._kvstore.load_optimizer_states(fname)
else:
with open(fname, 'rb') as fin:
    self._updater.set_states(fin.read())
def install_monitor(self, mon):
""" Install monitor on all executors """
assert self.binded
self._exec_group.install_monitor(mon)