The module API provides an intermediate and high-level interface for performing computation with neural networks in MXNet. A module is an instance of subclasses of the BaseModule
. The most widely used module class is called Module
. Module wraps a Symbol
and one or more Executors
. For a full list of functions, see BaseModule
. A subclass of modules might have extra interface functions. This topic provides some examples of common use cases. All of the module APIs are in the Module
namespace.
To construct a module, refer to the constructors for the module class. For example, the Module
class accepts a Symbol
as input:
import ml.dmlc.mxnet._ import ml.dmlc.mxnet.module.{FitParams, Module} // construct a simple MLP val data = Symbol.Variable("data") val fc1 = Symbol.FullyConnected(name = "fc1")(data)(Map("num_hidden" -> 128)) val act1 = Symbol.Activation(name = "relu1")(fc1)(Map("act_type" -> "relu")) val fc2 = Symbol.FullyConnected(name = "fc2")(act1)(Map("num_hidden" -> 64)) val act2 = Symbol.Activation(name = "relu2")(fc2)(Map("act_type" -> "relu")) val fc3 = Symbol.FullyConnected(name = "fc3")(act2)(Map("num_hidden" -> 10)) val out = Symbol.SoftmaxOutput(name = "softmax")(fc3)() // construct the module val mod = new Module(out)
By default, context
is the CPU. If you need data parallelization, you can specify a GPU context or an array of GPU contexts.
Before you can compute with a module, you need to call bind()
to allocate the device memory and initParams()
or SetParams()
to initialize the parameters. If you simply want to fit a module, you don't need to call bind()
and initParams()
explicitly, because the fit() function automatically calls them if they are needed.
mod.bind(dataShapes = train_dataiter.provideData, labelShapes = Some(train_dataiter.provideLabel)) mod.initParams()
Now you can compute with the module using functions like forward()
, backward()
, etc.
Modules provide high-level APIs for training, predicting, and evaluating. To fit a module, call the fit()
function with some DataIter
s:
import ml.dmlc.mxnet.optimizer.SGD val mod = new Module(softmax) mod.fit(train_dataiter, evalData = scala.Option(eval_dataiter), \ numEpoch = n_epoch, fitParams = new FitParams()\ .setOptimizer(new SGD(learningRate = 0.1f, momentum = 0.9f, wd = 0.0001f)))
The interface is very similar to the old FeedForward
class. You can pass in batch-end callbacks using setBatchEndCallback
and epoch-end callbacks using setEpochEndCallback
. You can also set parameters using methods like setOptimizer
and setEvalMetric
. To learn more about the FitParams()
, see the API page. To predict with a module, call predict()
with a DataIter
:
mod.predict(val_dataiter)
The module collects and returns all of the prediction results. For more details about the format of the return values, see the documentation for the predict()
function.
When prediction results might be too large to fit in memory, use the predictEveryBatch
API:
val preds = mod.predictEveryBatch(val_dataiter) val_dataiter.reset() var i = 0 while (val_dataiter.hasNext) { val batch = val_dataiter.next() val predLabel: Array[Int] = NDArray.argmax_channel(preds(i)(0)).toArray.map(_.toInt) val label = batch.label(0).toArray.map(_.toInt) //do something... i += 1 }
If you need to evaluate on a test set and don't need the prediction output, call the score()
function with a DataIter
and an EvalMetric
:
mod.score(val_dataiter, metric)
This runs predictions on each batch in the provided DataIter
and computes the evaluation score using the provided EvalMetric
. The evaluation results are stored in metric
so that you can query later.
To save the module parameters in each training epoch, use a checkpoint
callback:
val modelPrefix: String = "mymodel" for (epoch <- 0 until 5) { while(train_dataiter.hasNext){ // forward backward pass //do something... } val checkpoint = mod.saveCheckpoint(modelPrefix, epoch, saveOptStates = true) }
To load the saved module parameters, call the loadCheckpoint
function:
val mod = Module.loadCheckpoint(modelPrefix, loadModelEpoch, loadOptimizerStates = true)
To initialize parameters, Bind the symbols to construct executors first with bind
method. Then, initialize the parameters and auxiliary states by calling initParams()
method.
mod.bind(dataShapes = train_dataiter.provideData, labelShapes = Some(train_dataiter.provideLabel)) mod.initParams()
To get current parameters, use getParams
method.
val (argParams, auxParams) = mod.getParams
To assign parameter and aux state values, use setParams
method.
mod.setParams(argParams, auxParams)
To resume training from a saved checkpoint, instead of calling setParams()
, directly call fit()
, passing the loaded parameters, so that fit()
knows to start from those parameters instead of initializing randomly:
mod.fit(..., fitParams=new FitParams().setArgParams(argParams).\ setAuxParams(auxParams).setBeginEpoch(beginEpoch))
Create an object of the FitParams()
class, and then use it to call the setBeginEpoch()
method to pass beginEpoch
so that fit()
knows to resume from a saved epoch.