# pylint: disable=too-many-instance-attributes, too-many-arguments, protected-access
# pylint: disable=too-many-public-methods
"""A `BucketingModule` implement the `BaseModule` API, and allows multiple
symbols to be used depending on the `bucket_key` provided by each different
mini-batch of data.
"""
import logging
import warnings
from .. import context as ctx
from ..initializer import Uniform
from .base_module import BaseModule, _check_input_names
from .module import Module
class BucketingModule(BaseModule):
"""A bucketing module is a module that support bucketing.
Parameters
----------
sym_gen : function
        A function that, when called with a bucket key, returns a triple
`(symbol, data_names, label_names)`.
default_bucket_key : str (or any python object)
The key for the default bucket.
logger : Logger
context : Context or list of Context
        Default `cpu()`.
work_load_list : list of number
Default `None`, indicating uniform workload.
    fixed_param_names : list of str
Default `None`, indicating no network parameters are fixed.
state_names : list of str
        States are similar to data and label, but not provided by the data
        iterator. Instead they are initialized to 0 and can be set by
        `set_states()`.
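    Examples
    --------
    A minimal usage sketch. The `gen_sym` function, bucket keys, and shapes
    below are illustrative only (a real `sym_gen` would typically unroll a
    network `seq_len` steps)::
    >>> def gen_sym(seq_len):
    >>>     data = mx.sym.Variable('data')
    >>>     label = mx.sym.Variable('softmax_label')
    >>>     pred = mx.sym.FullyConnected(data, num_hidden=10)
    >>>     net = mx.sym.SoftmaxOutput(pred, label, name='softmax')
    >>>     return (net, ('data',), ('softmax_label',))
    >>> mod = mx.mod.BucketingModule(gen_sym, default_bucket_key=10)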
"""
def __init__(self, sym_gen, default_bucket_key=None, logger=logging,
context=ctx.cpu(), work_load_list=None,
fixed_param_names=None, state_names=None):
super(BucketingModule, self).__init__(logger=logger)
assert default_bucket_key is not None
self._default_bucket_key = default_bucket_key
self._sym_gen = sym_gen
symbol, data_names, label_names = sym_gen(default_bucket_key)
data_names = list(data_names) if data_names is not None else []
label_names = list(label_names) if label_names is not None else []
state_names = list(state_names) if state_names is not None else []
fixed_param_names = list(fixed_param_names) if fixed_param_names is not None else []
_check_input_names(symbol, data_names, "data", True)
_check_input_names(symbol, label_names, "label", False)
_check_input_names(symbol, state_names, "state", True)
_check_input_names(symbol, fixed_param_names, "fixed_param", True)
self._fixed_param_names = fixed_param_names
self._state_names = state_names
self._context = context
self._work_load_list = work_load_list
self._buckets = {}
self._curr_module = None
self._curr_bucket_key = None
self._params_dirty = False
def _reset_bind(self):
"""Internal utility function to reset binding."""
self.binded = False
self._buckets = {}
self._curr_module = None
self._curr_bucket_key = None
@property
def data_names(self):
"""A list of names for data required by this module."""
if self.binded:
return self._curr_module.data_names
else:
_, data_names, _ = self._sym_gen(self._default_bucket_key)
return data_names
@property
def output_names(self):
"""A list of names for the outputs of this module."""
if self.binded:
return self._curr_module.output_names
else:
symbol, _, _ = self._sym_gen(self._default_bucket_key)
return symbol.list_outputs()
@property
def data_shapes(self):
"""Get data shapes.
Returns
-------
A list of `(name, shape)` pairs.
"""
assert self.binded
return self._curr_module.data_shapes
@property
def label_shapes(self):
"""Get label shapes.
Returns
-------
        A list of `(name, shape)` pairs. The return value could be `None` if
        the module does not need labels, or if the module is not bound for
        training (in this case, label information is not available).
"""
assert self.binded
return self._curr_module.label_shapes
@property
def output_shapes(self):
"""Get output shapes.
Returns
-------
A list of `(name, shape)` pairs.
"""
assert self.binded
return self._curr_module.output_shapes
def get_params(self):
"""Get current parameters.
Returns
-------
`(arg_params, aux_params)`, each a dictionary of name to parameters (in
`NDArray`) mapping.
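        Examples
        --------
        A sketch of reading the current parameters, assuming the module is
        already bound and initialized::
        >>> arg_params, aux_params = mod.get_params()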
"""
assert self.binded and self.params_initialized
self._curr_module._params_dirty = self._params_dirty
params = self._curr_module.get_params()
self._params_dirty = False
return params
def set_params(self, arg_params, aux_params, allow_missing=False, force_init=True):
"""Assign parameter and aux state values.
Parameters
----------
arg_params : dict
Dictionary of name to value (`NDArray`) mapping.
aux_params : dict
Dictionary of name to value (`NDArray`) mapping.
allow_missing : bool
If true, params could contain missing values, and the initializer will be
called to fill those missing params.
force_init : bool
If true, will force re-initialize even if already initialized.
Examples
--------
An example of setting module parameters::
>>> sym, arg_params, aux_params = \
        ... mx.model.load_checkpoint(model_prefix, n_epoch_load)
>>> mod.set_params(arg_params=arg_params, aux_params=aux_params)
"""
if not allow_missing:
self.init_params(initializer=None, arg_params=arg_params, aux_params=aux_params,
allow_missing=allow_missing, force_init=force_init)
return
if self.params_initialized and not force_init:
warnings.warn("Parameters already initialized and force_init=False. "
"set_params call ignored.", stacklevel=2)
return
self._curr_module.set_params(arg_params, aux_params, allow_missing=allow_missing,
force_init=force_init)
# because we didn't update self._arg_params, they are dirty now.
self._params_dirty = True
self.params_initialized = True
def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=None,
allow_missing=False, force_init=False):
"""Initialize parameters.
Parameters
----------
initializer : Initializer
arg_params : dict
Default `None`. Existing parameters. This has higher priority than `initializer`.
aux_params : dict
Default `None`. Existing auxiliary states. This has higher priority than `initializer`.
allow_missing : bool
Allow missing values in `arg_params` and `aux_params` (if not `None`). In this case,
missing values will be filled with `initializer`.
force_init : bool
Default `False`.
"""
if self.params_initialized and not force_init:
return
assert self.binded, 'call bind before initializing the parameters'
self._curr_module.init_params(initializer=initializer, arg_params=arg_params,
aux_params=aux_params, allow_missing=allow_missing,
force_init=force_init)
self._params_dirty = False
self.params_initialized = True
def get_states(self, merge_multi_context=True):
"""Get states from all devices
Parameters
----------
merge_multi_context : bool
Default is `True`. In the case when data-parallelism is used, the states
            will be collected from multiple devices. A `True` value indicates that we
            should merge the collected results so that they look as if they came
            from a single executor.
Returns
-------
If `merge_multi_context` is `True`, it is like `[out1, out2]`. Otherwise, it
is like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`. All the output
elements are `NDArray`.
"""
assert self.binded and self.params_initialized
return self._curr_module.get_states(merge_multi_context=merge_multi_context)
def set_states(self, states=None, value=None):
"""Set value for states. Only one of states & value can be specified.
Parameters
----------
        states : list of list of NDArray
            Source state arrays, formatted like
            `[[state1_dev1, state1_dev2], [state2_dev1, state2_dev2]]`.
        value : number
            A single scalar value for all state arrays.
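        Examples
        --------
        A sketch of resetting all states to zero between sequences; assumes the
        module was created with `state_names` and is bound and initialized::
        >>> mod.set_states(value=0)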
"""
assert self.binded and self.params_initialized
self._curr_module.set_states(states, value)
def bind(self, data_shapes, label_shapes=None, for_training=True,
inputs_need_grad=False, force_rebind=False, shared_module=None,
grad_req='write'):
"""Binding for a `BucketingModule` means setting up the buckets and bind the
executor for the default bucket key. Executors corresponding to other keys are
binded afterwards with `switch_bucket`.
Parameters
----------
data_shapes : list of (str, tuple)
This should correspond to the symbol for the default bucket.
label_shapes : list of (str, tuple)
This should correspond to the symbol for the default bucket.
for_training : bool
Default is `True`.
inputs_need_grad : bool
Default is `False`.
force_rebind : bool
Default is `False`.
shared_module : BucketingModule
Default is `None`. This value is currently not used.
grad_req : str, list of str, dict of str to str
Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
            (defaults to 'write').
Can be specified globally (str) or for each argument (list, dict).
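        Examples
        --------
        A sketch of binding with shapes for the default bucket; the input names
        and shapes are illustrative::
        >>> mod.bind(data_shapes=[('data', (32, 10))],
        >>>          label_shapes=[('softmax_label', (32,))])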
"""
# in case we already initialized params, keep it
if self.params_initialized:
arg_params, aux_params = self.get_params()
        # force rebinding is typically used when one wants to switch from
# training to prediction phase.
if force_rebind:
self._reset_bind()
if self.binded:
            self.logger.warning('Already bound, ignoring bind()')
return
assert shared_module is None, 'shared_module for BucketingModule is not supported'
self.for_training = for_training
self.inputs_need_grad = inputs_need_grad
self.binded = True
symbol, data_names, label_names = self._sym_gen(self._default_bucket_key)
module = Module(symbol, data_names, label_names, logger=self.logger,
context=self._context, work_load_list=self._work_load_list,
fixed_param_names=self._fixed_param_names,
state_names=self._state_names)
module.bind(data_shapes, label_shapes, for_training, inputs_need_grad,
force_rebind=False, shared_module=None, grad_req=grad_req)
self._curr_module = module
self._curr_bucket_key = self._default_bucket_key
self._buckets[self._default_bucket_key] = module
# copy back saved params, if already initialized
if self.params_initialized:
self.set_params(arg_params, aux_params)
def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
"""Switch to a different bucket. This will change `self.curr_module`.
Parameters
----------
bucket_key : str (or any python object)
The key of the target bucket.
data_shapes : list of (str, tuple)
Typically `data_batch.provide_data`.
label_shapes : list of (str, tuple)
Typically `data_batch.provide_label`.
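        Examples
        --------
        A sketch of switching to a hypothetical bucket keyed by sequence
        length 20::
        >>> mod.switch_bucket(20, data_shapes=[('data', (32, 20))],
        >>>                   label_shapes=[('softmax_label', (32,))])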
"""
assert self.binded, 'call bind before switching bucket'
        if bucket_key not in self._buckets:
symbol, data_names, label_names = self._sym_gen(bucket_key)
module = Module(symbol, data_names, label_names,
logger=self.logger, context=self._context,
work_load_list=self._work_load_list,
fixed_param_names=self._fixed_param_names,
state_names=self._state_names)
module.bind(data_shapes, label_shapes, self._curr_module.for_training,
self._curr_module.inputs_need_grad,
force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
self._buckets[bucket_key] = module
self._curr_module = self._buckets[bucket_key]
self._curr_bucket_key = bucket_key
def init_optimizer(self, kvstore='local', optimizer='sgd',
optimizer_params=(('learning_rate', 0.01),),
force_init=False):
"""Install and initialize optimizers.
Parameters
----------
kvstore : str or KVStore
Default `'local'`.
optimizer : str or Optimizer
            Default `'sgd'`.
        optimizer_params : dict
            Default `(('learning_rate', 0.01),)`. The default value is not a
            dictionary, just to avoid a pylint warning about dangerous default values.
force_init : bool
            Default `False`, indicating whether to force re-initializing the
            optimizer when one is already installed.
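        Examples
        --------
        A sketch of installing an SGD optimizer; the learning rate is
        illustrative::
        >>> mod.init_optimizer(optimizer='sgd',
        >>>                    optimizer_params=(('learning_rate', 0.1),))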
"""
assert self.binded and self.params_initialized
if self.optimizer_initialized and not force_init:
self.logger.warning('optimizer already initialized, ignoring.')
return
self._curr_module.init_optimizer(kvstore, optimizer, optimizer_params,
force_init=force_init)
for mod in self._buckets.values():
if mod is not self._curr_module:
mod.borrow_optimizer(self._curr_module)
self.optimizer_initialized = True
def prepare(self, data_batch):
        """Prepare a data batch for forward.
        Parameters
        ----------
        data_batch : DataBatch
        """
        # ensure the executor for this batch's bucket is bound, creating it if
        # necessary
assert self.binded and self.params_initialized
bucket_key = data_batch.bucket_key
original_bucket_key = self._curr_bucket_key
data_shapes = data_batch.provide_data
label_shapes = data_batch.provide_label
self.switch_bucket(bucket_key, data_shapes, label_shapes)
# switch back
self.switch_bucket(original_bucket_key, None, None)
def forward(self, data_batch, is_train=None):
"""Forward computation.
Parameters
----------
data_batch : DataBatch
is_train : bool
            Default is `None`, in which case `is_train` is taken to be `self.for_training`.
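        Examples
        --------
        A sketch of one training step; `data_iter` is assumed to yield batches
        carrying a `bucket_key`::
        >>> for batch in data_iter:
        >>>     mod.forward(batch, is_train=True)
        >>>     mod.backward()
        >>>     mod.update()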
"""
assert self.binded and self.params_initialized
self.switch_bucket(data_batch.bucket_key, data_batch.provide_data,
data_batch.provide_label)
self._curr_module.forward(data_batch, is_train=is_train)
def backward(self, out_grads=None):
"""Backward computation."""
assert self.binded and self.params_initialized
self._curr_module.backward(out_grads=out_grads)
def update(self):
"""Update parameters according to installed optimizer and the gradient computed
in the previous forward-backward cycle.
"""
assert self.binded and self.params_initialized and self.optimizer_initialized
self._params_dirty = True
self._curr_module.update()
def get_outputs(self, merge_multi_context=True):
"""Get outputs from a previous forward computation.
Parameters
----------
merge_multi_context : bool
Default is `True`. In the case when data-parallelism is used, the outputs
            will be collected from multiple devices. A `True` value indicates that we
            should merge the collected results so that they look as if they came
            from a single executor.
Returns
-------
If `merge_multi_context` is `True`, it is like `[out1, out2]`. Otherwise, it
is like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`. All the output
        elements are `NDArray`.
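        Examples
        --------
        A sketch of reading the merged outputs after a forward pass; the names
        are illustrative::
        >>> mod.forward(batch, is_train=False)
        >>> prob = mod.get_outputs()[0].asnumpy()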
"""
assert self.binded and self.params_initialized
return self._curr_module.get_outputs(merge_multi_context=merge_multi_context)
def get_input_grads(self, merge_multi_context=True):
"""Get the gradients with respect to the inputs of the module.
Parameters
----------
merge_multi_context : bool
Default is `True`. In the case when data-parallelism is used, the outputs
            will be collected from multiple devices. A `True` value indicates that we
            should merge the collected results so that they look as if they came
            from a single executor.
Returns
-------
If `merge_multi_context` is `True`, it is like `[grad1, grad2]`. Otherwise, it
is like `[[grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2]]`. All the output
elements are `NDArray`.
"""
assert self.binded and self.params_initialized and self.inputs_need_grad
return self._curr_module.get_input_grads(merge_multi_context=merge_multi_context)
def update_metric(self, eval_metric, labels):
"""Evaluate and accumulate evaluation metric on outputs of the last forward computation.
Parameters
----------
eval_metric : EvalMetric
labels : list of NDArray
Typically `data_batch.label`.
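        Examples
        --------
        A sketch of accumulating accuracy after a forward pass::
        >>> metric = mx.metric.Accuracy()
        >>> mod.update_metric(metric, batch.label)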
"""
assert self.binded and self.params_initialized
self._curr_module.update_metric(eval_metric, labels)
@property
def symbol(self):
"""The symbol of the current bucket being used."""
assert self.binded
return self._curr_module.symbol
def install_monitor(self, mon):
""" Install monitor on all executors """
assert self.binded
for mod in self._buckets.values():
mod.install_monitor(mon)