# pylint: disable=fixme, too-many-arguments, too-many-locals, too-many-public-methods, too-many-branches
"""`BaseModule` defines an API for modules."""
import time
import logging
import warnings
from .. import metric
from .. import ndarray
from ..context import cpu
from ..model import BatchEndParam
from ..initializer import Uniform
from ..io import DataDesc
def _as_list(obj):
"""A utility function that treat the argument as a list.
Parameters
----------
obj : object
Returns
-------
If `obj` is a list, return it. Otherwise, return `[obj]` as a single-element list.
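    Examples
    --------
    >>> _as_list([1, 2])
    [1, 2]
    >>> _as_list('pad')
    ['pad']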
"""
if isinstance(obj, list):
return obj
else:
return [obj]
def _check_input_names(symbol, names, typename, throw):
"""Check that all input names are in symbol's argument"""
args = symbol.list_arguments()
for name in names:
if name in args:
continue
candidates = [arg for arg in args if
not arg.endswith('_weight') and
not arg.endswith('_bias') and
not arg.endswith('_gamma') and
not arg.endswith('_beta')]
msg = "\033[91mYou created Module with Module(..., %s_names=%s) but " \
"input with name '%s' is not found in symbol.list_arguments(). " \
"Did you mean one of:\n\t%s\033[0m"%(
typename, str(names), name, '\n\t'.join(candidates))
if throw:
raise ValueError(msg)
else:
warnings.warn(msg)
def _check_names_match(data_names, data_shapes, name, throw):
"""Check that input names matches input data descriptors"""
actual = [x[0] for x in data_shapes]
if sorted(data_names) != sorted(actual):
msg = "Data provided by %s_shapes don't match names specified by %s_names (%s vs. %s)"%(
name, name, str(data_shapes), str(data_names))
if throw:
raise ValueError(msg)
else:
warnings.warn(msg)
def _parse_data_desc(data_names, label_names, data_shapes, label_shapes):
"""parse data_shapes into DataDesc format and check that names match"""
data_shapes = [x if isinstance(x, DataDesc) else DataDesc(*x) for x in data_shapes]
_check_names_match(data_names, data_shapes, 'data', True)
if label_shapes is not None:
label_shapes = [x if isinstance(x, DataDesc) else DataDesc(*x) for x in label_shapes]
_check_names_match(label_names, label_shapes, 'label', False)
else:
_check_names_match(label_names, [], 'label', False)
return data_shapes, label_shapes
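# A hypothetical example call (names are illustrative only):
#   _parse_data_desc(['data'], ['softmax_label'],
#                    [('data', (32, 100))], [('softmax_label', (32,))])
# returns the same shapes wrapped as `DataDesc` objects, raising on a data-name
# mismatch and warning on a label-name mismatch.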
class BaseModule(object):
"""The base class of a modules.
A module represents a computation component. The design purpose of a module
is that it abstract a computation "machine", that one can run forward,
backward, update parameters, etc. We aim to make the APIs easy to use,
especially in the case when we need to use imperative API to work with
multiple modules (e.g. stochastic depth network).
A module has several states:
    - Initial state. Memory is not allocated yet; the module is not ready for computation.
    - Binded. Shapes for inputs, outputs, and parameters are all known, memory has been
      allocated, and the module is ready for computation.
- Parameter initialized. For modules with parameters, doing computation before initializing
the parameters might result in undefined outputs.
- Optimizer installed. An optimizer can be installed to a module. After this, the parameters
of the module can be updated according to the optimizer after gradients are computed
(forward-backward).
    In order for a module to interact with others, it should be able to report the
    following information in its raw stage (before being bound):
- `data_names`: list of string indicating the names of required data.
- `output_names`: list of string indicating the names of required outputs.
    It should also report the following richer information after being bound:
- state information
      - `binded`: `bool`, indicating whether the memory buffers needed for computation
        have been allocated.
      - `for_training`: whether the module is bound for training (if bound).
      - `params_initialized`: `bool`, indicating whether the parameters of this module
        have been initialized.
      - `optimizer_initialized`: `bool`, indicating whether an optimizer is defined
        and initialized.
      - `inputs_need_grad`: `bool`, indicating whether gradients with respect to the
        input data are needed. Might be useful when implementing composition of modules.
- input/output information
- `data_shapes`: a list of `(name, shape)`. In theory, since the memory is allocated,
we could directly provide the data arrays. But in the case of data parallelization,
the data arrays might not be of the same shape as viewed from the external world.
      - `label_shapes`: a list of `(name, shape)`. This might be `[]` if the module does
        not need labels (e.g. it does not contain a loss function at the top), or if the
        module is not bound for training.
- `output_shapes`: a list of `(name, shape)` for outputs of the module.
- parameters (for modules with parameters)
      - `get_params()`: return a tuple `(arg_params, aux_params)`. Each of those
        is a dictionary mapping names to `NDArray`. Those `NDArray` always live on
        the CPU. The actual parameters used for computing might live on other devices
        (GPUs); this function will retrieve (a copy of) the latest parameters. Therefore,
        modifying the returned `NDArray` might not take effect on the computing devices.
- `set_params(arg_params, aux_params)`: assign parameters to the devices
doing the computation.
- `init_params(...)`: a more flexible interface to assign or initialize the parameters.
- setup
- `bind()`: prepare environment for computation.
- `init_optimizer()`: install optimizer for parameter updating.
- computation
- `forward(data_batch)`: forward operation.
- `backward(out_grads=None)`: backward operation.
- `update()`: update parameters according to installed optimizer.
- `get_outputs()`: get outputs of the previous forward operation.
- `get_input_grads()`: get the gradients with respect to the inputs computed
in the previous backward operation.
- `update_metric(metric, labels)`: update performance metric for the previous forward
computed results.
    - other properties (mostly for backward compatibility)
- `symbol`: the underlying symbolic graph for this module (if any)
This property is not necessarily constant. For example, for `BucketingModule`,
this property is simply the *current* symbol being used. For other modules,
this value might not be well defined.
    When these intermediate-level APIs are implemented properly, the following
    high-level APIs will be automatically available for a module:
- `fit`: train the module parameters on a data set
- `predict`: run prediction on a data set and collect outputs
- `score`: run prediction on a data set and evaluate performance
To create a module for training classification::
data = mx.sym.Variable('data')
output = mx.sym.FullyConnected(data, num_hidden=10)
label = mx.sym.Variable('label')
loss = mx.loss.softmax_cross_entropy_loss(output, label)
model = mx.mod.Module(loss, data_names=('data',))
model.fit(..., eval_metric=loss.metric)
model.score(..., eval_metric=loss.metric)
To create a module for prediction only::
data = mx.sym.Variable('data')
output = mx.sym.FullyConnected(data, num_hidden=10)
model = mx.mod.Module(output, data_names=('data',))
model.bind(data_shapes=[('data', (128, 100))], label_shapes=None)
model.load_params('save-0001.params')
model.predict(...)
You can also load from saved checkpoints::
model.save_checkpoint('save', 1)
model2 = mx.mod.Module.load('save', 1, context=mx.cpu(0))
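    The intermediate-level API can also be used directly. A minimal sketch of one
    training step, assuming the module is already bound, its parameters and optimizer
    are initialized, `batch` is a `DataBatch`, and `eval_metric` is an `EvalMetric`::
        model.forward(batch, is_train=True)            # compute outputs
        model.backward()                               # compute gradients
        model.update()                                 # apply the installed optimizer
        model.update_metric(eval_metric, batch.label)  # accumulate the metric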
"""
def __init__(self, logger=logging):
self.logger = logger
self.binded = False
self.for_training = False
self.inputs_need_grad = False
self.params_initialized = False
self.optimizer_initialized = False
self._symbol = None
self._total_exec_bytes = 0
################################################################################
# High Level API
################################################################################
def forward_backward(self, data_batch):
"""A convenient function that calls both `forward` and `backward`.
"""
self.forward(data_batch, is_train=True)
self.backward()
def score(self, eval_data, eval_metric, num_batch=None, batch_end_callback=None,
score_end_callback=None,
reset=True, epoch=0):
"""Run prediction on `eval_data` and evaluate the performance according to
`eval_metric`.
Parameters
----------
eval_data : DataIter
eval_metric : EvalMetric
num_batch : int
            Number of batches to run. Default is `None`, indicating that evaluation runs
            until the `DataIter` finishes.
batch_end_callback : function
Could also be a list of functions.
reset : bool
            Default `True`, indicating whether we should reset `eval_data` before starting
            to evaluate.
epoch : int
Default 0. For compatibility, this will be passed to callbacks (if any). During
training, this will correspond to the training epoch number.
Examples
--------
An example of using score for prediction::
>>> #Evaluate accuracy on val_dataiter
>>> metric = mx.metric.Accuracy()
>>> mod.score(val_dataiter, metric)
"""
assert self.binded and self.params_initialized
if reset:
eval_data.reset()
if not isinstance(eval_metric, metric.EvalMetric):
eval_metric = metric.create(eval_metric)
eval_metric.reset()
actual_num_batch = 0
for nbatch, eval_batch in enumerate(eval_data):
if num_batch is not None and nbatch == num_batch:
break
self.forward(eval_batch, is_train=False)
self.update_metric(eval_metric, eval_batch.label)
if batch_end_callback is not None:
batch_end_params = BatchEndParam(epoch=epoch,
nbatch=nbatch,
eval_metric=eval_metric,
locals=locals())
for callback in _as_list(batch_end_callback):
callback(batch_end_params)
actual_num_batch += 1
if score_end_callback:
params = BatchEndParam(epoch=epoch,
nbatch=actual_num_batch,
eval_metric=eval_metric,
locals=locals())
for callback in _as_list(score_end_callback):
callback(params)
return eval_metric.get_name_value()
def iter_predict(self, eval_data, num_batch=None, reset=True):
"""Iterate over predictions.
for pred, i_batch, batch in module.iter_predict(eval_data):
# pred is a list of outputs from the module
# i_batch is a integer
# batch is the data batch from the data iterator
Parameters
----------
eval_data : DataIter
num_batch : int
            Default is `None`, indicating that all batches in the data iterator will be run.
reset : bool
            Default is `True`, indicating whether we should reset the data iter before
            starting prediction.
"""
assert self.binded and self.params_initialized
if reset:
eval_data.reset()
for nbatch, eval_batch in enumerate(eval_data):
if num_batch is not None and nbatch == num_batch:
break
self.forward(eval_batch, is_train=False)
pad = eval_batch.pad
outputs = [out[0:out.shape[0]-pad] for out in self.get_outputs()]
yield (outputs, nbatch, eval_batch)
def predict(self, eval_data, num_batch=None, merge_batches=True, reset=True,
always_output_list=False):
"""Run prediction and collect the outputs.
        When `merge_batches` is `True` (the default), the return value will be a list
        `[out1, out2, out3]`, where each element is the concatenation of the outputs for
        all the mini-batches. If, in addition, `always_output_list` is `False` (the default),
        then in the case of a single output, `out1` is returned instead of `[out1]`.
When `merge_batches` is `False`, the return value will be a nested list like
`[[out1_batch1, out2_batch1], [out1_batch2], ...]`. This mode is useful because
in some cases (e.g. bucketing), the module does not necessarily produce the same
number of outputs.
        The objects in the results are `NDArray`s. If you need to work with a numpy array,
        just call `.asnumpy()` on each `NDArray`.
Parameters
----------
eval_data : DataIter
num_batch : int
            Default is `None`, indicating that all batches in the data iterator will be run.
merge_batches : bool
Default is `True`, see the doc for return values.
reset : bool
            Default is `True`, indicating whether we should reset the data iter before
            starting prediction.
always_output_list : bool
Default is `False`, see the doc for return values.
Returns
-------
list of NDArray or list of list of NDArray
Predict results
Examples
--------
An example of using predict for prediction::
>>> #Predict on the first 10 batches of val_dataiter
>>> mod.predict(eval_data=val_dataiter, num_batch=10)
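        A sketch of keeping per-batch outputs separate instead of merging them
        (the nested-list form described above)::
        >>> outputs_per_batch = mod.predict(eval_data=val_dataiter, merge_batches=False)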
"""
assert self.binded and self.params_initialized
if reset:
eval_data.reset()
output_list = []
for nbatch, eval_batch in enumerate(eval_data):
if num_batch is not None and nbatch == num_batch:
break
self.forward(eval_batch, is_train=False)
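            # strip the padded samples that the iterator appended to fill the last batch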
pad = eval_batch.pad
outputs = [out[0:out.shape[0]-pad].copy() for out in self.get_outputs()]
output_list.append(outputs)
if len(output_list) == 0:
return output_list
if merge_batches:
num_outputs = len(output_list[0])
for out in output_list:
assert len(out) == num_outputs, \
'Cannot merge batches, as num of outputs is not the same ' + \
'in mini-batches. Maybe bucketing is used?'
output_list2 = [ndarray.concatenate([out[i] for out in output_list])
for i in range(num_outputs)]
if num_outputs == 1 and not always_output_list:
return output_list2[0]
return output_list2
return output_list
def fit(self, train_data, eval_data=None, eval_metric='acc',
epoch_end_callback=None, batch_end_callback=None, kvstore='local',
optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
eval_end_callback=None,
eval_batch_end_callback=None, initializer=Uniform(0.01),
arg_params=None, aux_params=None, allow_missing=False,
force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None,
validation_metric=None, monitor=None):
"""Train the module parameters.
Parameters
----------
train_data : DataIter
eval_data : DataIter
If not `None`, will be used as validation set and evaluate the performance
after each epoch.
eval_metric : str or EvalMetric
            Default `'acc'`. The performance measure displayed during training.
epoch_end_callback : function or list of function
Each callback will be called with the current `epoch`, `symbol`, `arg_params`
and `aux_params`.
batch_end_callback : function or list of function
Each callback will be called with a `BatchEndParam`.
kvstore : str or KVStore
Default `'local'`.
optimizer : str or Optimizer
Default `'sgd'`
optimizer_params : dict
Default `(('learning_rate', 0.01),)`. The parameters for the optimizer constructor.
The default value is not a `dict`, just to avoid pylint warning on dangerous
default values.
eval_end_callback : function or list of function
These will be called at the end of each full evaluation, with the metrics over
the entire evaluation set.
eval_batch_end_callback : function or list of function
These will be called at the end of each minibatch during evaluation
initializer : Initializer
Will be called to initialize the module parameters if not already initialized.
arg_params : dict
            Defaults to `None`. If not `None`, this should be a dictionary of existing
            parameters from a trained model or loaded from a checkpoint (previously saved
            model). In this case, the values here will be used to initialize the module
            parameters, unless they are already initialized by the user via a call to
            `init_params` or `fit`. `arg_params` has higher priority than `initializer`.
aux_params : dict
Default `None`. Similar to `arg_params`, except for auxiliary states.
allow_missing : bool
Default `False`. Indicate whether we allow missing parameters when `arg_params`
and `aux_params` are not `None`. If this is `True`, then the missing parameters
will be initialized via the `initializer`.
force_rebind : bool
            Default `False`. Whether to force rebinding the executors if already bound.
force_init : bool
Default `False`. Indicate whether we should force initialization even if the
parameters are already initialized.
begin_epoch : int
Default `0`. Indicate the starting epoch. Usually, if we are resuming from a
checkpoint saved at a previous training phase at epoch N, then we should specify
this value as N+1.
num_epoch : int
Number of epochs to run training.
Examples
--------
An example of using fit for training::
>>> #Assume training dataIter and validation dataIter are ready
>>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter,
optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
num_epoch=10)
"""
assert num_epoch is not None, 'please specify number of epochs'
self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
for_training=True, force_rebind=force_rebind)
if monitor is not None:
self.install_monitor(monitor)
self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
allow_missing=allow_missing, force_init=force_init)
self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
optimizer_params=optimizer_params)
if validation_metric is None:
validation_metric = eval_metric
if not isinstance(eval_metric, metric.EvalMetric):
eval_metric = metric.create(eval_metric)
################################################################################
# training loop
################################################################################
for epoch in range(begin_epoch, num_epoch):
tic = time.time()
eval_metric.reset()
for nbatch, data_batch in enumerate(train_data):
if monitor is not None:
monitor.tic()
self.forward_backward(data_batch)
self.update()
self.update_metric(eval_metric, data_batch.label)
if monitor is not None:
monitor.toc_print()
if batch_end_callback is not None:
batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
eval_metric=eval_metric,
locals=locals())
for callback in _as_list(batch_end_callback):
callback(batch_end_params)
# one epoch of training is finished
for name, val in eval_metric.get_name_value():
self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
toc = time.time()
self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc-tic))
# sync aux params across devices
arg_params, aux_params = self.get_params()
self.set_params(arg_params, aux_params)
if epoch_end_callback is not None:
for callback in _as_list(epoch_end_callback):
callback(epoch, self.symbol, arg_params, aux_params)
#----------------------------------------
# evaluation on validation set
if eval_data:
res = self.score(eval_data, validation_metric,
score_end_callback=eval_end_callback,
batch_end_callback=eval_batch_end_callback, epoch=epoch)
#TODO: pull this into default
for name, val in res:
self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val)
# end of 1 epoch, reset the data-iter for another epoch
train_data.reset()
################################################################################
# Symbol information
################################################################################
@property
def data_names(self):
"""A list of names for data required by this module."""
raise NotImplementedError()
@property
def output_names(self):
"""A list of names for the outputs of this module."""
raise NotImplementedError()
################################################################################
# Input/Output information
################################################################################
@property
def data_shapes(self):
"""A list of (name, shape) pairs specifying the data inputs to this module."""
raise NotImplementedError()
@property
def label_shapes(self):
"""A list of (name, shape) pairs specifying the label inputs to this module.
        If this module does not accept labels -- either because it is a module without a
        loss function, or because it is not bound for training -- then this should return
        an empty list `[]`.
"""
raise NotImplementedError()
@property
def output_shapes(self):
"""A list of (name, shape) pairs specifying the outputs of this module."""
raise NotImplementedError()
################################################################################
# Parameters of a module
################################################################################
def get_params(self):
"""Get parameters, those are potentially copies of the the actual parameters used
to do computation on the device.
Returns
-------
`(arg_params, aux_params)`
            A pair of dictionaries, each mapping parameter names to `NDArray` values.
Examples
--------
An example of getting module parameters::
>>> print mod.get_params()
({'fc2_weight': <NDArray 64x128 @cpu(0)>, 'fc1_weight': <NDArray 128x100 @cpu(0)>,
'fc3_bias': <NDArray 10 @cpu(0)>, 'fc3_weight': <NDArray 10x64 @cpu(0)>,
'fc2_bias': <NDArray 64 @cpu(0)>, 'fc1_bias': <NDArray 128 @cpu(0)>}, {})
"""
raise NotImplementedError()
def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=None,
allow_missing=False, force_init=False):
"""Initialize the parameters and auxiliary states.
Parameters
----------
initializer : Initializer
Called to initialize parameters if needed.
arg_params : dict
If not None, should be a dictionary of existing arg_params. Initialization
will be copied from that.
aux_params : dict
If not None, should be a dictionary of existing aux_params. Initialization
will be copied from that.
allow_missing : bool
If true, params could contain missing values, and the initializer will be
called to fill those missing params.
force_init : bool
If true, will force re-initialize even if already initialized.
Examples
--------
An example of initializing module parameters::
>>> mod.init_params()
"""
raise NotImplementedError()
def set_params(self, arg_params, aux_params, allow_missing=False, force_init=True):
"""Assign parameter and aux state values.
Parameters
----------
arg_params : dict
Dictionary of name to value (`NDArray`) mapping.
aux_params : dict
Dictionary of name to value (`NDArray`) mapping.
allow_missing : bool
If true, params could contain missing values, and the initializer will be
called to fill those missing params.
force_init : bool
If true, will force re-initialize even if already initialized.
Examples
--------
An example of setting module parameters::
            >>> sym, arg_params, aux_params = \
            ...     mx.model.load_checkpoint(model_prefix, n_epoch_load)
>>> mod.set_params(arg_params=arg_params, aux_params=aux_params)
"""
self.init_params(initializer=None, arg_params=arg_params, aux_params=aux_params,
allow_missing=allow_missing, force_init=force_init)
def save_params(self, fname):
"""Save model parameters to file.
Parameters
----------
fname : str
Path to output param file.
Examples
--------
An example of saving module parameters::
>>> mod.save_params('myfile')
"""
arg_params, aux_params = self.get_params()
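        # serialize CPU copies of the parameters, prefixing keys with 'arg:' or 'aux:'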
save_dict = {('arg:%s' % k) : v.as_in_context(cpu()) for k, v in arg_params.items()}
save_dict.update({('aux:%s' % k) : v.as_in_context(cpu()) for k, v in aux_params.items()})
ndarray.save(fname, save_dict)
def load_params(self, fname):
"""Load model parameters from file.
Parameters
----------
fname : str
Path to input param file.
Examples
--------
        An example of loading module parameters::
>>> mod.load_params('myfile')
"""
save_dict = ndarray.load(fname)
arg_params = {}
aux_params = {}
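        # keys were written by save_params with 'arg:'/'aux:' prefixes; split them back out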
for k, value in save_dict.items():
arg_type, name = k.split(':', 1)
if arg_type == 'arg':
arg_params[name] = value
elif arg_type == 'aux':
aux_params[name] = value
else:
raise ValueError("Invalid param file " + fname)
self.set_params(arg_params, aux_params)
def get_states(self, merge_multi_context=True):
"""Get states from all devices
If `merge_multi_context` is `True`, it is like `[out1, out2]`. Otherwise, it
is like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`. All the output
elements are `NDArray`.
Parameters
----------
merge_multi_context : bool
            Default is `True`. In the case when data-parallelism is used, the states
            will be collected from multiple devices. A `True` value indicates that we
            should merge the collected results so that they look like they came from a
            single executor.
Returns
-------
list of NDArray or list of list of NDArray
States
"""
assert self.binded and self.params_initialized
assert not merge_multi_context
return []
def set_states(self, states=None, value=None):
"""Set value for states. Only one of states & value can be specified.
Parameters
----------
states : list of list of NDArrays
source states arrays formatted like [[state1_dev1, state1_dev2],
[state2_dev1, state2_dev2]].
value : number
a single scalar value for all state arrays.
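        Examples
        --------
        A sketch of resetting all the states of a stateful module to a scalar value::
        >>> mod.set_states(value=0)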
"""
assert self.binded and self.params_initialized
assert not states and not value
def install_monitor(self, mon):
"""Install monitor on all executors"""
raise NotImplementedError()
################################################################################
# Computations
################################################################################
def forward(self, data_batch, is_train=None):
"""Forward computation.
Parameters
----------
data_batch : DataBatch
            Could be anything with a similar API implemented.
is_train : bool
Default is `None`, which means `is_train` takes the value of `self.for_training`.
Examples
--------
An example of forward computation::
>>> from collections import namedtuple
>>> Batch = namedtuple('Batch', ['data'])
>>> mod.bind(data_shapes=[('data', (1, 10, 10))])
>>> mod.init_params()
>>> data1 = [mx.nd.ones([1, 10, 10])]
>>> mod.forward(Batch(data1))
>>> print mod.get_outputs()[0].asnumpy()
[[ 0.09999977 0.10000153 0.10000716 0.10000195 0.09999853 0.09999743
0.10000272 0.10000113 0.09999088 0.09999888]]
"""
raise NotImplementedError()
def backward(self, out_grads=None):
"""Backward computation.
Parameters
----------
out_grads : NDArray or list of NDArray, optional
Gradient on the outputs to be propagated back.
This parameter is only needed when bind is called
on outputs that are not a loss function.
Examples
--------
An example of backward computation::
>>> mod.backward()
>>> print mod.get_input_grads()[0].asnumpy()
[[[ 1.10182791e-05 5.12257748e-06 4.01927764e-06 8.32566820e-06
-1.59775993e-06 7.24269375e-06 7.28067835e-06 -1.65902311e-05
5.46342608e-06 8.44196393e-07]
...]]
"""
raise NotImplementedError()
def get_outputs(self, merge_multi_context=True):
"""Get outputs of the previous forward computation.
If `merge_multi_context` is `True`, it is like `[out1, out2]`. Otherwise, it
is like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`. All the output
elements are `NDArray`. When `merge_multi_context` is `False`, those `NDArray`
might live on different devices.
Parameters
----------
merge_multi_context : bool
            Default is `True`. In the case when data-parallelism is used, the outputs
            will be collected from multiple devices. A `True` value indicates that we
            should merge the collected results so that they look like they came from a
            single executor.
Returns
-------
list of NDArray or list of list of NDArray
Output
Examples
--------
An example of getting forward output::
>>> print mod.get_outputs()[0].asnumpy()
[[ 0.09999977 0.10000153 0.10000716 0.10000195 0.09999853 0.09999743
0.10000272 0.10000113 0.09999088 0.09999888]]
"""
raise NotImplementedError()
def get_input_grads(self, merge_multi_context=True):
"""Get the gradients to the inputs, computed in the previous backward computation.
If `merge_multi_context` is `True`, it is like `[grad1, grad2]`. Otherwise, it
is like `[[grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2]]`. All the output
elements are `NDArray`. When `merge_multi_context` is `False`, those `NDArray`
might live on different devices.
Parameters
----------
merge_multi_context : bool
            Default is `True`. In the case when data-parallelism is used, the gradients
            will be collected from multiple devices. A `True` value indicates that we
            should merge the collected results so that they look like they came from a
            single executor.
Returns
-------
list of NDArray or list of list of NDArray
Input gradients
Examples
--------
An example of getting input gradients::
>>> print mod.get_input_grads()[0].asnumpy()
[[[ 1.10182791e-05 5.12257748e-06 4.01927764e-06 8.32566820e-06
-1.59775993e-06 7.24269375e-06 7.28067835e-06 -1.65902311e-05
5.46342608e-06 8.44196393e-07]
...]]
"""
raise NotImplementedError()
def update(self):
"""Update parameters according to the installed optimizer and the gradients computed
in the previous forward-backward batch.
Examples
--------
An example of updating module parameters::
            >>> mod.init_optimizer(kvstore='local', optimizer='sgd',
            ...                    optimizer_params=(('learning_rate', 0.01), ))
>>> mod.backward()
>>> mod.update()
>>> print mod.get_params()[0]['fc3_weight'].asnumpy()
[[ 5.86930104e-03 5.28078526e-03 -8.88729654e-03 -1.08308345e-03
6.13054074e-03 4.27560415e-03 1.53817423e-03 4.62131854e-03
4.69872449e-03 -2.42400169e-03 9.94111411e-04 1.12386420e-03
...]]
"""
raise NotImplementedError()
def update_metric(self, eval_metric, labels):
"""Evaluate and accumulate evaluation metric on outputs of the last forward computation.
Parameters
----------
eval_metric : EvalMetric
labels : list of NDArray
Typically `data_batch.label`.
Examples
--------
An example of updating evaluation metric::
>>> mod.forward(data_batch)
>>> mod.update_metric(metric, data_batch.label)
"""
raise NotImplementedError()
################################################################################
# module setup
################################################################################
def bind(self, data_shapes, label_shapes=None, for_training=True,
inputs_need_grad=False, force_rebind=False, shared_module=None,
grad_req='write'):
"""Bind the symbols to construct executors. This is necessary before one
can perform computation with the module.
Parameters
----------
data_shapes : list of (str, tuple)
Typically is `data_iter.provide_data`.
label_shapes : list of (str, tuple)
Typically is `data_iter.provide_label`.
for_training : bool
            Default is `True`. Whether the executors should be bound for training.
inputs_need_grad : bool
Default is `False`. Whether the gradients to the input data need to be computed.
Typically this is not needed. But this might be needed when implementing composition
of modules.
force_rebind : bool
            Default is `False`. This function does nothing if the executors are already
            bound, but if this is `True`, the executors will be forced to rebind.
shared_module : Module
Default is `None`. This is used in bucketing. When not `None`, the shared module
essentially corresponds to a different bucket -- a module with different symbol
but with the same sets of parameters (e.g. unrolled RNNs with different lengths).
grad_req : str, list of str, dict of str to str
Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
(default to 'write').
Can be specified globally (str) or for each argument (list, dict).
Examples
--------
An example of binding symbols::
>>> mod.bind(data_shapes=[('data', (1, 10, 10))])
"""
raise NotImplementedError()
def init_optimizer(self, kvstore='local', optimizer='sgd',
optimizer_params=(('learning_rate', 0.01),), force_init=False):
"""Install and initialize optimizers.
Parameters
----------
kvstore : str or KVStore
Default `'local'`.
optimizer : str or Optimizer
Default `'sgd'`
optimizer_params : dict
Default `(('learning_rate', 0.01),)`. The default value is not a dictionary,
just to avoid pylint warning of dangerous default values.
force_init : bool
Default `False`, indicating whether we should force re-initializing the
optimizer in the case an optimizer is already installed.
Examples
--------
An example of initializing optimizer::
>>> mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.005),))
"""
raise NotImplementedError()
################################################################################
# misc
################################################################################
@property
def symbol(self):
"""Get the symbol associated with this module.
        Except for `Module`, for other types of modules (e.g. `BucketingModule`), this
        property might not be constant throughout its lifetime. Some modules might not
        even be associated with any symbols.
"""
return self._symbol