.. currentmodule:: mxnet.module

The module API, defined in the ``module`` (or simply ``mod``) package, provides an intermediate- and high-level interface for performing computation with a ``Symbol`` or ``Loss``. Roughly speaking, a module is a machine that executes a program defined by a ``Symbol`` or ``Loss``.

The ``module.Module`` class is a commonly used module that accepts a ``Symbol`` or ``Loss`` as input:

.. code-block:: python

    import mxnet as mx

    data = mx.symbol.Variable('data')
    label = mx.sym.Variable('label')
    fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
    act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
    fc2 = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=10)
    loss = mx.loss.softmax_cross_entropy_loss(fc2, label)
    mod = mx.mod.Module(loss, data_names=('data',))


Alternatively, if you only want to make predictions, or want to compute the loss manually outside the module and feed the gradients back with ``Module.backward(out_grads=...)``, you can feed a ``Symbol`` directly into the module:

.. code-block:: python

    import mxnet as mx

    data = mx.symbol.Variable('data')
    label = mx.sym.Variable('label')
    fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
    act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
    fc2 = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=10)
    mod = mx.mod.Module(fc2, data_names=('data',))

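
When the loss is computed outside the module, ``out_grads`` should hold the gradient of that loss with respect to the module's outputs. As a minimal pure-Python sketch (it does not call MXNet, and the logits, labels, and three-class size are made-up values), the gradient of softmax cross-entropy with respect to the ``fc2`` logits is ``softmax(logits) - one_hot(label)`` for each sample:

.. code-block:: python

    import math

    def softmax(logits):
        """Numerically stable softmax over one row of logits."""
        m = max(logits)
        exps = [math.exp(z - m) for z in logits]
        total = sum(exps)
        return [e / total for e in exps]

    def softmax_ce_loss_and_grad(logits, label):
        """Return (loss, d_loss/d_logits) for one sample with an integer class label."""
        probs = softmax(logits)
        loss = -math.log(probs[label])
        grad = [p - (1.0 if k == label else 0.0) for k, p in enumerate(probs)]
        return loss, grad

    # Hypothetical fc2 outputs for a batch of two samples (3 classes instead of 10).
    batch_logits = [[2.0, 1.0, 0.1], [0.5, 0.5, 0.5]]
    batch_labels = [0, 2]

    losses, out_grads = [], []
    for logits, label in zip(batch_logits, batch_labels):
        loss, grad = softmax_ce_loss_and_grad(logits, label)
        losses.append(loss)
        out_grads.append(grad)

With MXNet, the nested list could then be passed as ``mod.backward(out_grads=[mx.nd.array(out_grads)])`` after a ``mod.forward(data_batch, is_train=True)`` call. This sketch sums the loss over the batch; if the loss were averaged instead, each gradient row would also be divided by the batch size.
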

Assume there is a valid MXNet data iterator ``data``. We can then initialize the module:

.. code-block:: python

    # allocate memory given the input shapes
    mod.bind(data_shapes=data.provide_data, label_shapes=data.provide_label)
    # initialize parameters; weights are drawn uniformly from [-0.05, 0.05]
    mod.init_params(mx.init.Uniform(0.05))

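
``bind`` needs the input shapes because parameter shapes are inferred from them. The pure-Python sketch below, which does not use MXNet, mimics that inference for the two ``FullyConnected`` layers defined earlier: a layer with ``num_hidden`` outputs and a flattened input of size ``in_dim`` owns a ``(num_hidden, in_dim)`` weight and a ``(num_hidden,)`` bias. It then fills the weights from U(-0.05, 0.05) and sets the biases to zero, which is how MXNet initializers treat ``*_bias`` parameters by default. The ``(32, 784)`` input shape and the helper names are assumptions for illustration.

.. code-block:: python

    import random

    def fc_param_shapes(name, in_dim, num_hidden):
        """Hypothetical helper: parameter shapes of a FullyConnected layer."""
        return {name + '_weight': (num_hidden, in_dim), name + '_bias': (num_hidden,)}

    def init_params(shapes, scale=0.05, seed=0):
        """Weights ~ U(-scale, scale); biases start at zero (one flat list per parameter)."""
        rng = random.Random(seed)
        params = {}
        for name, shape in shapes.items():
            size = 1
            for dim in shape:
                size *= dim
            if name.endswith('_bias'):
                params[name] = [0.0] * size
            else:
                params[name] = [rng.uniform(-scale, scale) for _ in range(size)]
        return params

    batch_size, in_dim = 32, 784  # assumed data shape (32, 784)
    shapes = {}
    shapes.update(fc_param_shapes('fc1', in_dim, 128))
    shapes.update(fc_param_shapes('fc2', 128, 10))
    output_shape = (batch_size, shapes['fc2_weight'][0])  # one row of 10 scores per sample
    params = init_params(shapes)
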

Now the module is ready to compute. We can call the high-level API to train and predict:

.. code-block:: python

    mod.fit(data, num_epoch=10, ...)  # train
    mod.predict(new_data)             # predict on new data


Alternatively, use the intermediate-level API to perform the computation step by step:

.. code-block:: python

    # create the default optimizer (SGD) before calling update()
    mod.init_optimizer()
    for data_batch in data:
        mod.forward(data_batch)  # forward on the provided data batch
        mod.backward()           # backward to calculate the gradients
        mod.update()             # update parameters using the optimizer

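
The call order above can be mimicked without MXNet. In the hypothetical ``ToyModule`` below, a linear model ``y = w * x + b`` with a squared-error loss stands in for a real network; the method names mirror the module API, but the bodies are a sketch, not MXNet's implementation. Like ``Module.update``, its ``update`` refuses to run before an optimizer has been initialized.

.. code-block:: python

    class ToyModule:
        """Hypothetical stand-in for a module: y = w * x + b with a squared-error loss."""

        def __init__(self):
            self.binded = False
            self.params_initialized = False
            self.optimizer_initialized = False

        def bind(self):
            # a real module would allocate memory from the input shapes here
            self.binded = True

        def init_params(self, w=0.0, b=0.0):
            assert self.binded, "call bind() first"
            self.w, self.b = w, b
            self.params_initialized = True

        def init_optimizer(self, learning_rate=0.01):
            # plain SGD; 0.01 echoes the default learning rate of Module.init_optimizer
            assert self.params_initialized, "call init_params() first"
            self.lr = learning_rate
            self.optimizer_initialized = True

        def forward(self, data_batch):
            self.xs, self.ys = data_batch
            self.preds = [self.w * x + self.b for x in self.xs]

        def backward(self):
            # gradients of mean((pred - y) ** 2) with respect to w and b
            n = len(self.xs)
            errs = [p - y for p, y in zip(self.preds, self.ys)]
            self.grad_w = sum(2 * e * x for e, x in zip(errs, self.xs)) / n
            self.grad_b = sum(2 * e for e in errs) / n

        def update(self):
            assert self.optimizer_initialized, "call init_optimizer() first"
            self.w -= self.lr * self.grad_w
            self.b -= self.lr * self.grad_b


    # hypothetical data: two (inputs, labels) mini-batches sampled from y = 2x + 1
    data = [([0.0, 1.0], [1.0, 3.0]), ([2.0, 3.0], [5.0, 7.0])]

    mod = ToyModule()
    mod.bind()
    mod.init_params()
    mod.init_optimizer(learning_rate=0.05)
    for epoch in range(500):
        for data_batch in data:
            mod.forward(data_batch)  # forward on the provided data batch
            mod.backward()           # backward to calculate the gradients
            mod.update()             # SGD step: param -= lr * grad

Because both mini-batches lie exactly on ``y = 2x + 1``, plain SGD drives ``w`` and ``b`` toward 2 and 1.
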
A detailed tutorial is available at http://mxnet.io/tutorials/python/module.html.

.. note:: ``module`` replaces ``model``, which has been deprecated.

The ``module`` package provides several modules:

.. autosummary::
    :nosignatures:

    BaseModule
    Module
    SequentialModule
    BucketingModule
    PythonModule
    PythonLossModule

We summarize the interface for each class in the following sections.

BaseModule
----------

``BaseModule`` is the base class for all other module classes. It defines the interface that each module class should provide.

.. autosummary::
    :nosignatures:

    BaseModule.bind

.. autosummary::
    :nosignatures:

    BaseModule.init_params
    BaseModule.set_params
    BaseModule.get_params
    BaseModule.save_params
    BaseModule.load_params

.. autosummary::
    :nosignatures:

    BaseModule.fit
    BaseModule.score
    BaseModule.iter_predict
    BaseModule.predict

.. autosummary::
    :nosignatures:

    BaseModule.forward
    BaseModule.backward
    BaseModule.forward_backward

.. autosummary::
    :nosignatures:

    BaseModule.init_optimizer
    BaseModule.update
    BaseModule.update_metric

.. autosummary::
    :nosignatures:

    BaseModule.data_names
    BaseModule.output_names
    BaseModule.data_shapes
    BaseModule.label_shapes
    BaseModule.output_shapes
    BaseModule.get_outputs
    BaseModule.get_input_grads

.. autosummary::
    :nosignatures:

    BaseModule.get_states
    BaseModule.set_states
    BaseModule.install_monitor
    BaseModule.symbol

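
The split between an abstract interface and shared helpers can be illustrated with Python's ``abc`` module. MXNet's ``BaseModule.forward_backward`` is such a helper: it runs ``forward`` in training mode and then ``backward``. The classes below are a hypothetical, heavily reduced sketch of that idea, not MXNet's source.

.. code-block:: python

    from abc import ABC, abstractmethod

    class SketchBaseModule(ABC):
        """Hypothetical, reduced module base class."""

        @abstractmethod
        def forward(self, data_batch, is_train=None):
            """Compute outputs for one batch."""

        @abstractmethod
        def backward(self, out_grads=None):
            """Compute gradients, optionally seeded by out_grads."""

        def forward_backward(self, data_batch):
            # written once in the base class in terms of the abstract steps
            self.forward(data_batch, is_train=True)
            self.backward()

    class RecordingModule(SketchBaseModule):
        """Concrete subclass that only records the order of calls."""

        def __init__(self):
            self.calls = []

        def forward(self, data_batch, is_train=None):
            self.calls.append(('forward', is_train))

        def backward(self, out_grads=None):
            self.calls.append(('backward', out_grads))

    m = RecordingModule()
    m.forward_backward(data_batch=[1, 2, 3])
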

Besides the basic interface defined in ``BaseModule``, each module class supports additional functionality, summarized in the following sections.

Module
------

.. autosummary::
    :nosignatures:

    Module.load
    Module.save_checkpoint
    Module.reshape
    Module.borrow_optimizer
    Module.save_optimizer_states
    Module.load_optimizer_states

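
``Module.save_checkpoint`` and ``Module.load`` identify a checkpoint by a ``prefix`` and an ``epoch`` number. The helper below is a hypothetical sketch (not part of MXNet) of the file-naming scheme MXNet uses for such checkpoints: one symbol file per prefix, a zero-padded per-epoch parameter file and, optionally, an optimizer-states file. It is worth verifying the exact names against the installed version.

.. code-block:: python

    def checkpoint_files(prefix, epoch, save_optimizer_states=False):
        """Return the file names a checkpoint with this prefix and epoch would occupy."""
        files = ['%s-symbol.json' % prefix,           # network definition, shared by all epochs
                 '%s-%04d.params' % (prefix, epoch)]  # parameters for this epoch
        if save_optimizer_states:
            files.append('%s-%04d.states' % (prefix, epoch))  # optimizer state, e.g. momentum
        return files

    print(checkpoint_files('mymodel', 3, save_optimizer_states=True))
    # ['mymodel-symbol.json', 'mymodel-0003.params', 'mymodel-0003.states']

In MXNet terms, ``mod.save_checkpoint('mymodel', 3, save_optimizer_states=True)`` would write these three files, and ``mx.mod.Module.load('mymodel', 3, load_optimizer_states=True)`` would read them back; ``'mymodel'`` is an arbitrary example prefix.
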

BucketingModule
---------------

.. autosummary::
    :nosignatures:

    BucketingModule.switch_bucket

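
A ``BucketingModule`` keeps one executor per bucket key, all sharing the parameters of the default bucket, and ``switch_bucket`` selects the one matching the current batch, binding it on first use. Choosing the key is usually the data iterator's job, for example by taking the smallest bucket that fits a sequence. The pure-Python sketch below illustrates both steps; the bucket lengths, the padding value, the ``BucketSwitcher`` class, and its one-argument ``switch_bucket`` are hypothetical simplifications (MXNet's ``switch_bucket`` also receives data and label shapes).

.. code-block:: python

    import bisect

    BUCKETS = [10, 20, 30]  # assumed bucket lengths; the largest doubles as the default key
    PAD = 0                 # assumed padding token

    def pick_bucket(length, buckets=BUCKETS):
        """Smallest bucket length that can hold a sequence of the given length."""
        i = bisect.bisect_left(buckets, length)
        if i == len(buckets):
            raise ValueError('sequence is longer than the largest bucket')
        return buckets[i]

    def pad_to(seq, bucket_len, pad=PAD):
        """Right-pad a sequence to the bucket length."""
        return seq + [pad] * (bucket_len - len(seq))

    class BucketSwitcher:
        """Hypothetical: one 'executor' per bucket key, all sharing one parameter dict."""

        def __init__(self, default_bucket_key):
            self.params = {'embed_weight': [0.1, 0.2]}  # made-up parameters shared by all buckets
            self.executors = {}
            self.switch_bucket(default_bucket_key)

        def switch_bucket(self, bucket_key):
            if bucket_key not in self.executors:
                # a real module would bind a new executor here, reusing self.params
                self.executors[bucket_key] = {'seq_len': bucket_key, 'params': self.params}
            self.current = bucket_key

    switcher = BucketSwitcher(default_bucket_key=max(BUCKETS))
    for sentence in ([5, 8, 2], list(range(1, 16)), [7] * 10):
        key = pick_bucket(len(sentence))
        switcher.switch_bucket(key)
        padded = pad_to(sentence, key)  # what would be fed to the executor for this key
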

SequentialModule
----------------

.. autosummary::
    :nosignatures:

    SequentialModule.add

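
Conceptually, ``SequentialModule`` chains modules so that the outputs of one become the inputs of the next, while gradients travel back through them in reverse order. The pure-Python sketch below illustrates that wiring with hypothetical scalar stages rather than MXNet modules; its ``add`` returns ``self`` so calls can be chained, mirroring ``SequentialModule.add``.

.. code-block:: python

    class Scale:
        """Hypothetical stage computing k * x."""
        def __init__(self, k):
            self.k = k
        def forward(self, x):
            return self.k * x
        def backward(self, out_grad):
            return self.k * out_grad

    class Square:
        """Hypothetical stage computing x * x."""
        def forward(self, x):
            self.x = x  # keep the input for the backward pass
            return x * x
        def backward(self, out_grad):
            return 2 * self.x * out_grad

    class SequentialSketch:
        """Hypothetical chain: each stage's output is the next stage's input."""
        def __init__(self):
            self.stages = []
        def add(self, stage):
            self.stages.append(stage)
            return self  # allow chained add() calls
        def forward(self, x):
            for stage in self.stages:
                x = stage.forward(x)
            return x
        def backward(self, out_grad=1.0):
            # gradients flow through the stages in reverse order
            for stage in reversed(self.stages):
                out_grad = stage.backward(out_grad)
            return out_grad

    seq = SequentialSketch().add(Scale(3.0)).add(Square())
    y = seq.forward(2.0)    # (3 * 2) ** 2 = 36
    dy_dx = seq.backward()  # d/dx (3x) ** 2 = 18x = 36 at x = 2
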

.. autoclass:: mxnet.module.BaseModule
    :members:

.. autoclass:: mxnet.module.Module
    :members:

.. autoclass:: mxnet.module.BucketingModule
    :members:

.. autoclass:: mxnet.module.SequentialModule
    :members:

.. autoclass:: mxnet.module.PythonModule
    :members:

.. autoclass:: mxnet.module.PythonLossModule
    :members: