blob: 7e66036efd55286f89805fe158ac83510d9f3974 [file]
# pylint: disable=too-many-instance-attributes, too-many-arguments, protected-access, too-many-branches
# pylint: disable=too-many-public-methods
"""A `Module` implements the `BaseModule` API by wrapping a `Symbol` and one or
more `Executor`s for data parallelization.
"""
import logging
from .. import context as ctx
from .. import ndarray as nd
from .. import optimizer as opt
from .executor_group import DataParallelExecutorGroup
from ..model import _create_kvstore, _initialize_kvstore, _update_params, _update_params_on_kvstore
from ..model import load_checkpoint
from ..initializer import Uniform
from .base_module import BaseModule
from ..io import DataDesc
class Module(BaseModule):
"""Module is a basic module that wraps a `Symbol`. It is functionally the same
as the `FeedForward` model, except that it is exposed through the module API.
Parameters
----------
symbol : Symbol
data_names : list of str
Default is `('data',)` for a typical model used in image classification.
label_names : list of str
Default is `('softmax_label',)` for a typical model used in image
classification.
logger : Logger
Default is `logging`.
context : Context or list of Context
Default is `cpu()`.
work_load_list : list of number
Default `None`, indicating uniform workload.
fixed_param_names: list of str
Default `None`, indicating no network parameters are fixed.
"""
def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
             logger=logging, context=ctx.cpu(), work_load_list=None, fixed_param_names=None):
    """Record the symbol and its naming metadata; nothing is bound yet.

    Parameters
    ----------
    symbol : Symbol
        Network definition wrapped by this module.
    data_names : list of str
        Names of the data inputs.
    label_names : list of str
        Names of the label inputs; `None` is treated as "no labels".
    logger : Logger
        Forwarded to the base class constructor.
    context : Context or list of Context
        A single context is wrapped into a one-element list.
    work_load_list : list of number
        One entry per context; `None` gives every context a weight of 1.
    fixed_param_names : list of str
        Stored as-is and handed to the executor group when binding.
    """
    super(Module, self).__init__(logger=logger)

    # Normalise a lone context to a list so the rest of the class can
    # always treat `_context` as a sequence.
    self._context = [context] if isinstance(context, ctx.Context) else context

    if work_load_list is None:
        work_load_list = [1] * len(self._context)
    assert len(work_load_list) == len(self._context)
    self._work_load_list = work_load_list

    self._symbol = symbol
    self._data_names = list(data_names)
    self._label_names = [] if label_names is None else list(label_names)

    # Every symbol argument that is not an input is considered a parameter.
    input_names = self._data_names + self._label_names
    self._param_names = [arg for arg in symbol.list_arguments()
                         if arg not in input_names]
    self._fixed_param_names = fixed_param_names
    self._aux_names = symbol.list_auxiliary_states()
    self._output_names = symbol.list_outputs()

    # Parameter storage, filled in by `init_params` (or `load`).
    self._arg_params = None
    self._aux_params = None
    self._params_dirty = False

    # Optimizer state, filled in by `init_optimizer` / `borrow_optimizer`.
    self._optimizer = None
    self._kvstore = None
    self._update_on_kvstore = None
    self._updater = None
    self._preload_opt_states = None

    # Binding state, filled in by `bind`.
    self._exec_group = None
    self._data_shapes = None
    self._label_shapes = None
@staticmethod
def load(prefix, epoch, load_optimizer_states=False, **kwargs):
    """Build a `Module` from a checkpoint written earlier.

    Parameters
    ----------
    prefix : str
        Path prefix of the saved files. `load_checkpoint(prefix, epoch)`
        supplies the symbol and parameters; the optimizer states, if
        requested, are expected at "prefix-xxxx.states" where xxxx is the
        zero-padded epoch number.
    epoch : int
        Epoch to load.
    load_optimizer_states : bool
        When true, remember the states file so that `init_optimizer`
        loads it later. The checkpoint must have been written with
        `save_optimizer_states=True`.
    **kwargs
        Passed straight to the `Module` constructor (`data_names`,
        `label_names`, `logger`, `context`, `work_load_list`,
        `fixed_param_names`).

    Returns
    -------
    Module
        A module whose parameters are marked as initialized but which is
        not yet bound.
    """
    symbol, arg_params, aux_params = load_checkpoint(prefix, epoch)

    module = Module(symbol=symbol, **kwargs)
    module._arg_params = arg_params
    module._aux_params = aux_params
    module.params_initialized = True

    if load_optimizer_states:
        # Only the path is recorded here; the file is read lazily once an
        # optimizer exists.
        module._preload_opt_states = '%s-%04d.states' % (prefix, epoch)
    return module
def save_checkpoint(self, prefix, epoch, save_optimizer_states=False):
    """Save current progress to checkpoint.

    Use mx.callback.module_checkpoint as epoch_end_callback to save during training.

    Parameters
    ----------
    prefix : str
        The file prefix to checkpoint to. Files written are
        "prefix-symbol.json", "prefix-xxxx.params" and, optionally,
        "prefix-xxxx.states", where xxxx is the zero-padded epoch.
    epoch : int
        The current epoch number.
    save_optimizer_states : bool
        Whether to also save the optimizer states so training can be resumed.
    """
    self._symbol.save('%s-symbol.json' % prefix)

    param_name = '%s-%04d.params' % (prefix, epoch)
    self.save_params(param_name)
    # Report through the module's own logger (the one passed to the
    # constructor) rather than the global `logging` module, consistent with
    # `bind` and `init_optimizer`.
    self.logger.info('Saved checkpoint to \"%s\"', param_name)

    if save_optimizer_states:
        state_name = '%s-%04d.states' % (prefix, epoch)
        self.save_optimizer_states(state_name)
        self.logger.info('Saved optimizer state to \"%s\"', state_name)
def _reset_bind(self):
    """Drop the current binding: executor group and cached input shapes."""
    self.binded = False
    # Cleared in the same order as they are set up by `bind`'s counterpart.
    self._exec_group = self._data_shapes = self._label_shapes = None
@property
def data_names(self):
    """Names of the data inputs, exactly as stored by the constructor."""
    names = self._data_names
    return names
@property
def output_names(self):
    """Names of the module's outputs, exactly as stored by the constructor."""
    names = self._output_names
    return names
@property
def data_shapes(self):
    """Data input descriptions recorded when the module was bound.

    Returns
    -------
    The list stored in `_data_shapes`.

    Raises
    ------
    AssertionError
        If the module is not bound.
    """
    assert self.binded
    shapes = self._data_shapes
    return shapes
@property
def label_shapes(self):
    """Label input descriptions recorded when the module was bound.

    Returns
    -------
    The list stored in `_label_shapes`, or `None` when no label
    information is held (for example, no label shapes were supplied).

    Raises
    ------
    AssertionError
        If the module is not bound.
    """
    assert self.binded
    shapes = self._label_shapes
    return shapes
@property
def output_shapes(self):
    """Output shapes as reported by the executor group.

    Returns
    -------
    Whatever `self._exec_group.get_output_shapes()` returns.

    Raises
    ------
    AssertionError
        If the module is not bound.
    """
    assert self.binded
    group = self._exec_group
    return group.get_output_shapes()
def get_params(self):
    """Return the module's parameter dictionaries.

    If the parameters are flagged dirty, they are first refreshed through
    `_sync_params_from_devices`.

    Returns
    -------
    `(arg_params, aux_params)`, the two name-to-array dictionaries held by
    the module.

    Raises
    ------
    AssertionError
        If the module is not bound or its parameters are not initialized.
    """
    assert self.binded and self.params_initialized

    stale = self._params_dirty
    if stale:
        self._sync_params_from_devices()

    arg_params, aux_params = self._arg_params, self._aux_params
    return arg_params, aux_params
def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=None,
                allow_missing=False, force_init=False):
    """Initialize the parameters and auxiliary states.

    Parameters
    ----------
    initializer : Initializer
        Called as `initializer(name, arr)` for every array that is not
        supplied through `arg_params` / `aux_params`.
    arg_params : dict
        If not None, should be a dictionary of existing arg_params. Values
        are copied from it.
    aux_params : dict
        If not None, should be a dictionary of existing aux_params. Values
        are copied from it.
    allow_missing : bool
        If true, the given dictionaries may lack entries; the initializer
        (when not None) is called to fill those.
    force_init : bool
        If true, re-initialize even if already initialized.

    Raises
    ------
    AssertionError
        If the module is not bound.
    RuntimeError
        If a dictionary was given, lacks an entry, and `allow_missing` is
        false.
    """
    if self.params_initialized and not force_init:
        return
    assert self.binded, 'call bind before initializing the parameters'

    # Allocate zero-filled module-side copies on first use. Each entry of
    # `param_arrays` / `aux_arrays` is indexed with [0] for its shape and
    # dtype -- presumably one array per device, all alike (TODO confirm).
    if self._arg_params is None:
        param_arrays = [nd.zeros(x[0].shape, dtype=x[0].dtype)
                        for x in self._exec_group.param_arrays]
        self._arg_params = dict(zip(self._param_names, param_arrays))

    if self._aux_params is None:
        aux_arrays = [nd.zeros(x[0].shape, dtype=x[0].dtype)
                      for x in self._exec_group.aux_arrays]
        self._aux_params = dict(zip(self._aux_names, aux_arrays))

    def _impl(name, arr, cache):
        """Fill `arr` from `cache[name]` if available, else via the initializer."""
        if cache is None:
            initializer(name, arr)
            return
        if name in cache:
            cache_arr = cache[name]
            # The cached array may be the target itself; skip the self-copy.
            if cache_arr is not arr:
                cache_arr.copyto(arr)
            return
        if not allow_missing:
            raise RuntimeError("%s is not presented" % name)
        # Identity check: `!= None` would dispatch to a user-defined
        # comparison on the initializer object.
        if initializer is not None:
            initializer(name, arr)

    for name, arr in self._arg_params.items():
        _impl(name, arr, arg_params)

    for name, arr in self._aux_params.items():
        _impl(name, arr, aux_params)

    self.params_initialized = True
    self._params_dirty = False

    # copy the initialized parameters to devices
    self._exec_group.set_params(self._arg_params, self._aux_params)
def bind(self, data_shapes, label_shapes=None, for_training=True,
         inputs_need_grad=False, force_rebind=False, shared_module=None,
         grad_req='write'):
    """Construct the executor group for this module's symbol.

    Computation is only possible after binding.

    Parameters
    ----------
    data_shapes : list of (str, tuple) or DataDesc
        Typically `data_iter.provide_data`. Entries that are not already
        `DataDesc` are expanded into one.
    label_shapes : list of (str, tuple) or DataDesc
        Typically `data_iter.provide_label`. May be `None`.
    for_training : bool
        Default is `True`. Whether the executors are bound for training.
    inputs_need_grad : bool
        Default is `False`. Whether gradients with respect to the input
        data are needed; must be false unless `for_training` is true.
    force_rebind : bool
        Default is `False`. An already bound module ignores the call
        (with a warning) unless this is true, in which case the previous
        binding is discarded first.
    shared_module : Module
        Default is `None`. Used in bucketing: a bound, initialized module
        whose executor group, parameters and (if initialized) optimizer are
        shared with this one.
    grad_req : str
        Default is `'write'`. Forwarded unchanged to
        `DataParallelExecutorGroup`; presumably the gradient accumulation
        mode -- confirm against the executor group.
    """
    # Forced rebinding is typically used to switch from training to
    # prediction.
    if force_rebind:
        self._reset_bind()

    if self.binded:
        self.logger.warning('Already binded, ignoring bind()')
        return

    self.for_training = for_training
    self.inputs_need_grad = inputs_need_grad
    self.binded = True

    # Input gradients are only allowed when training. Labels are NOT
    # required for training: some modules have no loss that consumes them.
    assert for_training or not inputs_need_grad

    def _as_desc(entry):
        """Return `entry` as a `DataDesc`, expanding plain tuples."""
        return entry if isinstance(entry, DataDesc) else DataDesc(*entry)

    self._data_shapes = [_as_desc(entry) for entry in data_shapes]
    self._label_shapes = (None if label_shapes is None
                          else [_as_desc(entry) for entry in label_shapes])

    shared_group = None
    if shared_module is not None:
        assert isinstance(shared_module, Module) and \
            shared_module.binded and shared_module.params_initialized
        shared_group = shared_module._exec_group

    # Collect the dtype of every named input, data first, then labels.
    input_types = {desc.name: desc.dtype for desc in self._data_shapes}
    if self._label_shapes is not None:
        for desc in self._label_shapes:
            input_types[desc.name] = desc.dtype

    self._exec_group = DataParallelExecutorGroup(
        self._symbol, self._context, self._work_load_list,
        self._data_shapes, self._label_shapes, self._param_names,
        for_training, inputs_need_grad, shared_group,
        logger=self.logger,
        fixed_param_names=self._fixed_param_names,
        grad_req=grad_req,
        input_types=input_types)

    if shared_module is not None:
        # Share the parameter dictionaries (and optimizer, if any) with the
        # module we were bound against.
        self.params_initialized = True
        self._arg_params = shared_module._arg_params
        self._aux_params = shared_module._aux_params
        if shared_module.optimizer_initialized:
            self.borrow_optimizer(shared_module)
    elif self.params_initialized:
        # Re-binding with parameters already present: push them to the new
        # executors right away.
        self._exec_group.set_params(self._arg_params, self._aux_params)
def init_optimizer(self, kvstore='local', optimizer='sgd',
                   optimizer_params=(('learning_rate', 0.01),), force_init=False):
    """Install and initialize the optimizer (and kvstore, if any).

    Parameters
    ----------
    kvstore : str or KVStore
        Default `'local'`.
    optimizer : str or Optimizer
        Default `'sgd'`. A string is turned into an optimizer through
        `opt.create`; anything else must already be an `opt.Optimizer`.
    optimizer_params : dict
        Default `(('learning_rate', 0.01),)`. Anything `dict()` accepts;
        the default is a tuple of pairs only to avoid a mutable default.
        Used only when `optimizer` is a string. `rescale_grad` defaults to
        `1.0 / batch_size` when not supplied.
    force_init : bool
        Default `False`. Whether to re-initialize when an optimizer is
        already installed.
    """
    assert self.binded and self.params_initialized

    if self.optimizer_initialized and not force_init:
        self.logger.warning('optimizer already initialized, ignoring...')
        return

    kvstore, update_on_kvstore = _create_kvstore(
        kvstore, len(self._context), self._arg_params)

    if not isinstance(optimizer, str):
        assert isinstance(optimizer, opt.Optimizer)
    else:
        batch_size = self._exec_group.batch_size
        if kvstore and kvstore.type == 'dist_sync':
            batch_size *= kvstore.num_workers

        # Map updater indices to parameter names. Without kvstore-side
        # updates there is one index per (parameter, device) pair.
        param_names = self._exec_group.param_names
        if update_on_kvstore:
            idx2name = dict(enumerate(param_names))
        else:
            num_device = len(self._context)
            idx2name = {i * num_device + k: name
                        for k in range(num_device)
                        for i, name in enumerate(param_names)}

        optimizer_params = dict(optimizer_params)
        if 'rescale_grad' not in optimizer_params:
            optimizer_params['rescale_grad'] = 1.0/batch_size

        optimizer = opt.create(optimizer,
                               sym=self.symbol, param_idx2name=idx2name,
                               **optimizer_params)

    self._optimizer = optimizer
    self._kvstore = kvstore
    self._update_on_kvstore = update_on_kvstore
    self._updater = None

    if kvstore:
        # copy initialized local parameters to kvstore
        _initialize_kvstore(kvstore=kvstore,
                            param_arrays=self._exec_group.param_arrays,
                            arg_params=self._arg_params,
                            param_names=self._param_names,
                            update_on_kvstore=update_on_kvstore)

    if update_on_kvstore:
        kvstore.set_optimizer(self._optimizer)
    else:
        self._updater = opt.get_updater(optimizer)

    self.optimizer_initialized = True

    # A states file remembered by `load` is consumed exactly once.
    pending_states = self._preload_opt_states
    if pending_states is not None:
        self.load_optimizer_states(pending_states)
        self._preload_opt_states = None
def borrow_optimizer(self, shared_module):
    """Reuse the optimizer objects of another module.

    Used in bucketing, where exactly the same optimizer (esp. kvstore) is
    shared between modules.

    Parameters
    ----------
    shared_module : Module
        Module whose optimizer is already initialized.
    """
    assert shared_module.optimizer_initialized

    shared_attrs = ('_optimizer', '_kvstore', '_update_on_kvstore', '_updater')
    for attr in shared_attrs:
        setattr(self, attr, getattr(shared_module, attr))
    self.optimizer_initialized = True
def forward(self, data_batch, is_train=None):
    """Run the forward pass on the executor group.

    Parameters
    ----------
    data_batch : DataBatch
        Could be anything with a similar API implemented.
    is_train : bool
        Passed through as-is. With the default `None`, the executor group
        is expected to fall back to `self.for_training` (not visible here).
    """
    assert self.binded and self.params_initialized
    group = self._exec_group
    group.forward(data_batch, is_train)
def backward(self, out_grads=None):
    """Run the backward pass on the executor group.

    Parameters
    ----------
    out_grads : NDArray or list of NDArray, optional
        Gradient on the outputs to be propagated back. Only needed when
        the bound outputs are not a loss function.
    """
    assert self.binded and self.params_initialized
    group = self._exec_group
    group.backward(out_grads=out_grads)
def update(self):
    """Apply one optimizer step using the gradients of the last
    forward-backward pass.

    Marks the module-side parameter copies as dirty, since only the arrays
    held by the executor group are updated.
    """
    assert self.binded and self.params_initialized and self.optimizer_initialized

    self._params_dirty = True
    group = self._exec_group

    if self._update_on_kvstore:
        _update_params_on_kvstore(group.param_arrays,
                                  group.grad_arrays,
                                  self._kvstore)
        return

    _update_params(group.param_arrays,
                   group.grad_arrays,
                   updater=self._updater,
                   num_device=len(self._context),
                   kvstore=self._kvstore)
def get_outputs(self, merge_multi_context=True):
    """Fetch the outputs of the previous forward computation.

    Parameters
    ----------
    merge_multi_context : bool
        Default is `True`. With data parallelism the outputs come from
        several devices; `True` asks for them to be merged so they look
        like the result of a single executor.

    Returns
    -------
    If `merge_multi_context` is `True`, it is like `[out1, out2]`.
    Otherwise, it is like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`.
    All the output elements are `NDArray`.
    """
    assert self.binded and self.params_initialized
    group = self._exec_group
    return group.get_outputs(merge_multi_context=merge_multi_context)
def get_input_grads(self, merge_multi_context=True):
    """Fetch the gradients with respect to the module's inputs.

    Only available when the module was bound with `inputs_need_grad`.

    Parameters
    ----------
    merge_multi_context : bool
        Default is `True`. With data parallelism the gradients come from
        several devices; `True` asks for them to be merged so they look
        like the result of a single executor.

    Returns
    -------
    If `merge_multi_context` is `True`, it is like `[grad1, grad2]`.
    Otherwise, it is like `[[grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2]]`.
    All the output elements are `NDArray`.
    """
    assert self.binded and self.params_initialized and self.inputs_need_grad
    group = self._exec_group
    return group.get_input_grads(merge_multi_context=merge_multi_context)
def update_metric(self, eval_metric, labels):
    """Accumulate an evaluation metric over the last forward outputs.

    The work is delegated to the executor group.

    Parameters
    ----------
    eval_metric : EvalMetric
    labels : list of NDArray
        Typically `data_batch.label`.
    """
    group = self._exec_group
    group.update_metric(eval_metric, labels)
def _sync_params_from_devices(self):
    """Synchronize parameters from devices to CPU.

    Must be called after `update` (which only touches the arrays held by
    the executor group) before the latest values can be read from
    `self._arg_params` and `self._aux_params`.

    Once the copy has completed, the dirty flag is cleared so that later
    readers (e.g. `get_params`) do not repeat the transfer until the next
    `update` marks the parameters dirty again.
    """
    self._exec_group.get_params(self._arg_params, self._aux_params)
    # Cleared only after a successful copy; if `get_params` raises, the
    # parameters remain flagged as dirty.
    self._params_dirty = False
def save_optimizer_states(self, fname):
    """Write the optimizer (updater) state to a file.

    When updates happen on the kvstore, saving is delegated to it;
    otherwise the local updater's states are written in binary mode.

    Parameters
    ----------
    fname : str
        Path to output states file.
    """
    assert self.optimizer_initialized

    if self._update_on_kvstore:
        self._kvstore.save_optimizer_states(fname)
        return

    with open(fname, 'wb') as state_file:
        state_file.write(self._updater.get_states())
def load_optimizer_states(self, fname):
    """Load optimizer (updater) state from file.

    When updates happen on the kvstore, loading is delegated to it;
    otherwise the file is read in binary mode and handed to the local
    updater.

    Parameters
    ----------
    fname : str
        Path to input states file.
    """
    assert self.optimizer_initialized

    if self._update_on_kvstore:
        self._kvstore.load_optimizer_states(fname)
    else:
        # Read inside a context manager so the handle is closed
        # deterministically, even if reading fails.
        with open(fname, 'rb') as fin:
            states = fin.read()
        self._updater.set_states(states)
def install_monitor(self, mon):
    """Attach `mon` to the executor group; the module must be bound."""
    assert self.binded
    group = self._exec_group
    group.install_monitor(mon)