/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.mxnet.module
import java.io.IOException
import ml.dmlc.mxnet.optimizer.SGD
import ml.dmlc.mxnet._
import org.slf4j.LoggerFactory
import scala.collection.mutable.ArrayBuffer
/**
* The base class of a module. A module represents a computation component. The design
* purpose of a module is that it abstracts a computation "machine", that one can run forward,
* backward, update parameters, etc. We aim to make the APIs easy to use, especially in the
* case when we need to use imperative API to work with multiple modules (e.g. stochastic
* depth network).
*
* A module has several states:
*
* - Initial state. Memory is not allocated yet; the module is not ready for computation.
* - Bound. Shapes for inputs, outputs, and parameters are all known, memory is allocated,
* ready for computation.
* - Parameters initialized. For modules with parameters, doing computation before initializing
* the parameters might result in undefined outputs.
* - Optimizer installed. An optimizer can be installed to a module. After this, the parameters
* of the module can be updated according to the optimizer after gradients are computed
* (forward-backward).
*
* In order for a module to interact with others, a module should be able to report the
* following information in its raw stage (before being bound):
*
* - `dataNames`: list of strings indicating the names of the required data.
* - `outputNames`: list of strings indicating the names of the outputs.
*
* And also the following richer information after being bound:
*
* - state information
* - `binded`: `bool`, indicating whether the memory buffers needed for computation
* have been allocated.
* - `forTraining`: whether the module is bound for training (if bound).
* - `paramsInitialized`: `bool`, indicating whether the parameters of this module
* have been initialized.
* - `optimizerInitialized`: `bool`, indicating whether an optimizer is defined
* and initialized.
* - `inputsNeedGrad`: `bool`, indicating whether gradients with respect to the
* input data are needed. Might be useful when implementing composition of modules.
*
* - input/output information
* - `dataShapes`: a list of `(name, shape)`. In theory, since the memory is allocated,
* we could directly provide the data arrays. But in the case of data parallelization,
* the data arrays might not be of the same shape as viewed from the external world.
* - `labelShapes`: a list of `(name, shape)`. This might be `[]` if the module does
* not need labels (e.g. it does not contain a loss function at the top), or the module
* is not bound for training.
* - `outputShapes`: a list of `(name, shape)` for outputs of the module.
*
* - parameters (for modules with parameters)
* - `getParams()`: return a tuple `(argParams, auxParams)`. Each of those
* is a map of name to `NDArray`. Those `NDArray`s always live on the
* CPU. The actual parameters used for computing might live on other devices (GPUs);
* this function will retrieve (a copy of) the latest parameters. Therefore, modifying
* the returned copies does not affect the parameters used for computation.
* - `setParams(argParams, auxParams)`: assign parameters to the devices
* doing the computation.
* - `initParams(...)`: a more flexible interface to assign or initialize the parameters.
*
* - setup
* - `bind()`: prepare environment for computation.
* - `initOptimizer()`: install optimizer for parameter updating.
*
* - computation
* - `forward(dataBatch)`: forward operation.
* - `backward(outGrads=None)`: backward operation.
* - `update()`: update parameters according to installed optimizer.
* - `getOutputs()`: get outputs of the previous forward operation.
* - `getInputGrads()`: get the gradients with respect to the inputs computed
* in the previous backward operation.
* - `updateMetric(metric, labels)`: update the performance metric on the results of
* the previous forward computation.
*
* - other properties (mostly for backward compatibility)
* - `symbol`: the underlying symbolic graph for this module (if any)
* This property is not necessarily constant. For example, for `BucketingModule`,
* this property is simply the *current* symbol being used. For other modules,
* this value might not be well defined.
*
* When those intermediate-level APIs are implemented properly, the following
* high-level APIs will be automatically available for a module:
*
* - `fit`: train the module parameters on a data set
* - `predict`: run prediction on a data set and collect outputs
* - `score`: run prediction on a data set and evaluate performance
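*
* A typical end-to-end use of a concrete module, shown as a sketch (here `net`,
* `trainIter` and `valIter` are assumed to be an existing `Symbol` and `DataIter`s,
* and `Module` is the usual concrete subclass, assumed constructible from a `Symbol`):
* {{{
* val mod: BaseModule = new Module(net)
* mod.fit(trainIter, evalData = Some(valIter), numEpoch = 10)
* val metric = mod.score(valIter, new Accuracy())
* val outputs = mod.predict(valIter)
* }}}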
*/
abstract class BaseModule {
private val logger = LoggerFactory.getLogger(classOf[BaseModule])
private[module] var binded: Boolean = false
private[module] var forTraining: Boolean = false
private[module] var inputsNeedGrad: Boolean = false
private[module] var paramsInitialized: Boolean = false
private[module] var optimizerInitialized: Boolean = false
private[module] var symbol: Symbol = null
private[module] var execGroup: DataParallelExecutorGroup = null
private[module] var argParams: Map[String, NDArray] = null
private[module] var auxParams: Map[String, NDArray] = null
// High Level API
def getSymbol: Symbol = this.symbol
// A convenient function that calls both `forward` and `backward`.
def forwardBackward(dataBatch: DataBatch): Unit = {
forward(dataBatch, isTrain = Option(true))
backward()
}
/**
* Run prediction on `evalData` and evaluate the performance according to `evalMetric`.
* @param evalData : DataIter
* @param evalMetric : EvalMetric
* @param numBatch Number of batches to run. Default is `Integer.MAX_VALUE`,
* indicating running until the `DataIter` finishes.
* @param batchEndCallback Could also be a list of functions.
* @param reset Default `True`,
* indicating whether we should reset `evalData` before starting to evaluate.
* @param epoch Default 0. For compatibility, this will be passed to callbacks (if any).
* During training, this will correspond to the training epoch number.
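*
* A minimal sketch (assuming `mod` is a bound and initialized module and `valIter`
* an existing validation `DataIter`):
* {{{
* val metric = mod.score(valIter, new Accuracy())
* val (name, value) = metric.get
* println(s"Validation-$name=$value")
* }}}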
*/
def score(evalData: DataIter, evalMetric: EvalMetric,
numBatch: Int = Integer.MAX_VALUE,
batchEndCallback: Option[BatchEndCallback] = None,
scoreEndCallback: Option[BatchEndCallback] = None,
reset: Boolean = true, epoch: Int = 0): EvalMetric = {
require(evalData != null && evalMetric != null)
require(binded && paramsInitialized)
if (reset) {
evalData.reset()
}
evalMetric.reset()
var nBatch = 0
while (evalData.hasNext && nBatch < numBatch) {
val evalBatch = evalData.next()
forward(evalBatch, isTrain = Option(false))
updateMetric(evalMetric, evalBatch.label)
batchEndCallback.foreach(callback => {
callback.invoke(epoch, nBatch, evalMetric)
})
evalBatch.dispose()
nBatch += 1
}
scoreEndCallback.foreach(callback => {
callback.invoke(epoch, nBatch, evalMetric)
})
evalMetric
}
/**
* Run prediction and collect the outputs.
* @param evalData
* @param numBatch Default is -1, indicating running all the batches in the data iterator.
* @param reset Default is `True`, indicating whether we should reset the data iter before start
* doing prediction.
* @return The return value will be a nested list like
* `[[out1_batch1, out2_batch1, ...], [out1_batch2, out2_batch2, ...]]`.
* This mode is useful because in some cases (e.g. bucketing),
* the module does not necessarily produce the same number of outputs.
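*
* For example, a sketch (assuming `mod` is bound with initialized parameters and
* `testIter` is an existing `DataIter`):
* {{{
* val perBatch = mod.predictEveryBatch(testIter)
* // perBatch(i)(j) is the j-th output of the i-th mini-batch
* val firstOutputOfFirstBatch = perBatch(0)(0)
* }}}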
*/
def predictEveryBatch(evalData: DataIter, numBatch: Int = -1, reset: Boolean = true)
: IndexedSeq[IndexedSeq[NDArray]] = {
require(binded && paramsInitialized)
if (reset) {
evalData.reset()
}
val outputList = ArrayBuffer.empty[IndexedSeq[NDArray]]
var nBatch = 0
while (evalData.hasNext && nBatch != numBatch) {
val evalBatch = evalData.next()
outputList.append(predict(evalBatch))
evalBatch.dispose()
nBatch += 1
}
outputList
}
def predict(batch: DataBatch): IndexedSeq[NDArray] = {
require(binded && paramsInitialized)
forward(batch, isTrain = Option(false))
val pad = batch.pad
getOutputsMerged().map(out => {
val withoutPadding = out.slice(0, out.shape(0)-pad)
val copied = withoutPadding.copy()
withoutPadding.dispose()
copied
})
}
/**
* Run prediction and collect the outputs.
* @param evalData
* @param numBatch Default is -1, indicating running all the batches in the data iterator.
* @param reset Default is `True`, indicating whether we should reset the data iter before start
* doing prediction.
* @return The return value will be a list `[out1, out2, out3]`,
* where each element is the concatenation of the outputs for all mini-batches.
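*
* A usage sketch (assuming `mod` is bound and initialized, `testIter` is an existing
* `DataIter`, and the module has a single output):
* {{{
* val outs = mod.predict(testIter)
* val prob = outs(0) // output concatenated over all mini-batches
* }}}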
*/
def predict(evalData: DataIter, numBatch: Int = -1, reset: Boolean = true)
: IndexedSeq[NDArray] = {
val outputBatches = predictEveryBatch(evalData, numBatch, reset)
val numOutputs = outputBatches.head.size
outputBatches.foreach(out =>
require(out.size == numOutputs,
"Cannot merge batches, as num of outputs is not the same in mini-batches." +
"Maybe bucketing is used?")
)
val concatenatedOutput = outputBatches.map(out => NDArray.concatenate(out))
outputBatches.foreach(_.foreach(_.dispose()))
concatenatedOutput
}
// Symbol information
// A list of names for data required by this module.
def dataNames: IndexedSeq[String]
// A list of names for the outputs of this module.
def outputNames: IndexedSeq[String]
// Input/Output information
// A list of (name, shape) pairs specifying the data inputs to this module.
def dataShapes: IndexedSeq[DataDesc]
/**
* A list of (name, shape) pairs specifying the label inputs to this module.
* If this module does not accept labels -- either it is a module without loss
* function, or it is not bound for training, then this should return an empty
* list `[]`.
*/
def labelShapes: IndexedSeq[DataDesc]
// A list of (name, shape) pairs specifying the outputs of this module.
def outputShapes: IndexedSeq[(String, Shape)]
// Parameters of a module
/**
* Get parameters; these are potentially copies of the actual parameters used
* for computation on the device.
* @return `(argParams, auxParams)`, a pair of maps from name to `NDArray`.
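*
* For example, a sketch (assuming `mod` is an initialized module):
* {{{
* val (argParams, auxParams) = mod.getParams
* // the returned NDArrays are CPU copies, so inspecting them is safe
* argParams.foreach { case (name, arr) => println(s"$name: ${arr.shape}") }
* }}}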
*/
def getParams: (Map[String, NDArray], Map[String, NDArray])
/**
* Initialize the parameters and auxiliary states.
* @param initializer Called to initialize parameters if needed.
* @param argParams If not null, should be a map of existing argParams. Initialization
* will be copied from it.
* @param auxParams If not null, should be a map of existing auxParams. Initialization
* will be copied from it.
* @param allowMissing If true, params could contain missing values, and the initializer
* will be called to fill those missing params.
* @param forceInit If true, will force re-initialization even if the parameters are
* already initialized.
* @param allowExtra Whether to allow extra parameters that are not needed by the symbol.
* If true, no error will be thrown when argParams or auxParams
* contain extra parameters that are not needed by the executor.
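*
* A minimal sketch (assuming `mod` has already been bound; `loadedArg`/`loadedAux`
* are assumed maps of previously obtained parameters):
* {{{
* mod.initParams(new Uniform(0.01f))
* // or resume from existing parameters, letting the initializer fill anything missing
* mod.initParams(argParams = loadedArg, auxParams = loadedAux, allowMissing = true)
* }}}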
*/
def initParams(initializer: Initializer = new Uniform(0.01f),
argParams: Map[String, NDArray] = null,
auxParams: Map[String, NDArray] = null,
allowMissing: Boolean = false,
forceInit: Boolean = false,
allowExtra: Boolean = false): Unit
/**
* Assign parameter and aux state values.
* @param argParams Map of name to value (`NDArray`) mapping.
* @param auxParams Map of name to value (`NDArray`) mapping.
* @param allowMissing If true, params could contain missing values, and the initializer
* will be called to fill those missing params.
* @param forceInit If true, will force re-initialization even if the parameters are
* already initialized.
* @param allowExtra Whether to allow extra parameters that are not needed by the symbol.
* If true, no error will be thrown when argParams or auxParams
* contain extra parameters that are not needed by the executor.
*/
def setParams(argParams: Map[String, NDArray],
auxParams: Map[String, NDArray],
allowMissing: Boolean = false,
forceInit: Boolean = true,
allowExtra: Boolean = false): Unit = {
initParams(initializer = null, argParams, auxParams,
allowMissing, forceInit, allowExtra)
}
/**
* Save model parameters to file.
* @param fname Path to output param file.
*
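* A sketch (the file name is illustrative):
* {{{
* mod.saveParams("model-0001.params")
* }}}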
*/
def saveParams(fname: String): Unit = {
val (argParams, auxParams) = getParams
val saveDict = (
argParams.map { case (k, v) => (s"arg:$k", v.asInContext(Context.cpu())) }
++ auxParams.map { case (k, v) => (s"aux:$k", v.asInContext(Context.cpu())) }
)
NDArray.save(fname, saveDict)
}
/**
* Load model parameters from file.
* @param fname Path to input param file.
* @throws IOException if param file is invalid
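*
* A sketch of restoring previously saved parameters (the file name is illustrative;
* entries are expected to carry the `arg:`/`aux:` prefixes written by `saveParams`):
* {{{
* mod.loadParams("model-0001.params")
* }}}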
*/
@throws(classOf[IOException])
def loadParams(fname: String): Unit = {
val saveDict = NDArray.load(fname)
val argParams = scala.collection.mutable.HashMap.empty[String, NDArray]
val auxParams = scala.collection.mutable.HashMap.empty[String, NDArray]
(saveDict._1 zip saveDict._2) foreach { case (key, value) =>
key.split(":", 2) match {
case Array(argType, name) if argType == "arg" => argParams.put(name, value)
case Array(argType, name) if argType == "aux" => auxParams.put(name, value)
case _ => throw new IOException("Invalid param file " + fname)
}
}
setParams(argParams.toMap, auxParams.toMap)
}
/**
*
* Train the module parameters.
* @param trainData
* @param evalData If set, it will be used as a validation set and the performance
* will be evaluated after each epoch.
* @param numEpoch Number of epochs to run training.
* @param fitParams Extra parameters for training.
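*
* For example, a sketch (assuming `trainIter` and `valIter` are existing `DataIter`s):
* {{{
* mod.fit(trainIter, evalData = Some(valIter), numEpoch = 10,
*   fitParams = new FitParams()
*     .setOptimizer(new SGD())
*     .setEvalMetric(new Accuracy()))
* }}}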
*/
def fit(trainData: DataIter, evalData: Option[DataIter] = None, numEpoch: Int = 1,
fitParams: FitParams = new FitParams): Unit = {
require(fitParams != null)
require(numEpoch > 0, "please specify number of epochs")
import ml.dmlc.mxnet.DataDesc._
bind(dataShapes = trainData.provideData, labelShapes = Option(trainData.provideLabel),
forTraining = true, forceRebind = fitParams.forceRebind)
fitParams.monitor.foreach(installMonitor)
initParams(fitParams.initializer, argParams, auxParams,
fitParams.allowMissing, fitParams.forceInit)
initOptimizer(fitParams.kvstore, fitParams.optimizer)
val valMetric = fitParams.validationMetric.getOrElse(fitParams.evalMetric)
// training loop
for (epoch <- fitParams.beginEpoch until numEpoch) {
val tic = System.currentTimeMillis
fitParams.evalMetric.reset()
var nBatch = 0
while (trainData.hasNext) {
val dataBatch = trainData.next()
fitParams.monitor.foreach(_.tic())
forwardBackward(dataBatch)
update()
updateMetric(fitParams.evalMetric, dataBatch.label)
fitParams.monitor.foreach(_.tocPrint())
fitParams.batchEndCallback.foreach(callback =>
callback.invoke(epoch, nBatch, fitParams.evalMetric)
)
dataBatch.dispose()
nBatch += 1
}
// one epoch of training is finished
val (name, value) = fitParams.evalMetric.get
logger.info(s"Epoch[$epoch] Train-$name=$value")
val toc = System.currentTimeMillis
logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
// sync aux params across devices
val (argParamsSync, auxParamsSync) = getParams
setParams(argParamsSync, auxParamsSync)
fitParams.epochEndCallback.foreach(callback =>
callback.invoke(epoch, symbol, argParamsSync, auxParamsSync)
)
// evaluation on validation set
evalData.foreach(data => {
val res = score(data, valMetric,
scoreEndCallback = fitParams.evalEndCallback,
batchEndCallback = fitParams.evalBatchEndCallback, epoch = epoch)
val (name, value) = res.get
logger.info(s"Epoch[$epoch] Validation-$name=$value")
})
// end of 1 epoch, reset the data-iter for another epoch
trainData.reset()
}
}
// Install monitor on all executors
def installMonitor(monitor: Monitor): Unit
// Computations
/**
* Forward computation.
* @param dataBatch Could be anything with a similar API implemented.
* @param isTrain Default is `None`, which means `isTrain` takes the value of `this.forTraining`.
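*
* A manual training step, shown as a sketch (assuming `mod` is bound for training with
* an installed optimizer, `batch` is a `DataBatch`, and `metric` is an `EvalMetric`):
* {{{
* mod.forward(batch, isTrain = Some(true))
* mod.backward()
* mod.update()
* mod.updateMetric(metric, batch.label)
* }}}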
*/
def forward(dataBatch: DataBatch, isTrain: Option[Boolean] = None): Unit
/**
* Backward computation.
* @param outGrads Gradient on the outputs to be propagated back.
* This parameter is only needed when bind is called
* on outputs that are not a loss function.
*/
def backward(outGrads: Array[NDArray] = null): Unit
/**
* Get outputs of the previous forward computation.
* @return In the case when data-parallelism is used,
* the outputs will be merged from multiple devices,
* as if they came from a single executor.
* The results will look like `[out1, out2]`
*/
def getOutputsMerged(): IndexedSeq[NDArray]
/**
* Get outputs of the previous forward computation.
* @return In the case when data-parallelism is used,
* the outputs will be collected from multiple devices.
* The results will look like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`,
* those `NDArray` might live on different devices.
*/
def getOutputs(): IndexedSeq[IndexedSeq[NDArray]]
/**
* Get the gradients to the inputs, computed in the previous backward computation.
* @return In the case when data-parallelism is used,
* the grads will be merged from multiple devices,
* as if they came from a single executor.
* The results will look like `[grad1, grad2]`
*/
def getInputGradsMerged(): IndexedSeq[NDArray]
/**
* Get the gradients to the inputs, computed in the previous backward computation.
* @return In the case when data-parallelism is used,
* the grads will be collected from multiple devices.
* The results will look like `[[grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2]]`,
* those `NDArray` might live on different devices.
*/
def getInputGrads(): IndexedSeq[IndexedSeq[NDArray]]
// Update parameters according to the installed optimizer and the gradients computed
// in the previous forward-backward batch.
def update(): Unit
/**
* Evaluate and accumulate evaluation metric on outputs of the last forward computation.
* @param evalMetric
* @param labels Typically `DataBatch.label`.
*/
def updateMetric(evalMetric: EvalMetric, labels: IndexedSeq[NDArray]): Unit
// module setup
/**
* Bind the symbols to construct executors.
* This is necessary before one can perform computation with the module.
* @param dataShapes Typically is `DataIter.provideData`.
* @param labelShapes Typically is `DataIter.provideLabel`.
* @param forTraining Default is `True`. Whether the executors should be bound for training.
* @param inputsNeedGrad Default is `False`.
* Whether the gradients to the input data need to be computed.
* Typically this is not needed.
* But this might be needed when implementing composition of modules.
* @param forceRebind Default is `False`. This function does nothing
* if the executors are already bound. But with this set to `True`,
* the executors will be forced to rebind.
* @param sharedModule Default is `None`. This is used in bucketing. When not `None`,
* the shared module essentially corresponds to a different bucket
* -- a module with different symbol but with the same sets of parameters
* (e.g. unrolled RNNs with different lengths).
* @param gradReq Requirement for gradient accumulation (globally).
* Can be 'write', 'add', or 'null' (defaults to 'write').
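*
* For example, a sketch (assuming `trainIter` is an existing `DataIter`; the implicit
* conversions from `ml.dmlc.mxnet.DataDesc._` turn its provided shapes into `DataDesc`s):
* {{{
* import ml.dmlc.mxnet.DataDesc._
* mod.bind(dataShapes = trainIter.provideData,
*   labelShapes = Option(trainIter.provideLabel), forTraining = true)
* }}}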
*/
def bind(dataShapes: IndexedSeq[DataDesc], labelShapes: Option[IndexedSeq[DataDesc]] = None,
forTraining: Boolean = true, inputsNeedGrad: Boolean = false,
forceRebind: Boolean = false, sharedModule: Option[BaseModule] = None,
gradReq: String = "write"): Unit
// Install and initialize optimizers.
def initOptimizer(kvstore: String = "local", optimizer: Optimizer = new SGD(),
resetOptimizer: Boolean = true, forceInit: Boolean = false): Unit
}
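/**
* Carries the optional training settings consumed by `BaseModule.fit`.
* A builder-style usage sketch (optimizer, metric, and epoch values are illustrative;
* `mod` and `trainIter` are assumed to be an existing module and `DataIter`):
* {{{
* val fp = new FitParams()
*   .setOptimizer(new SGD())
*   .setEvalMetric(new Accuracy())
*   .setBeginEpoch(0)
* mod.fit(trainIter, numEpoch = 5, fitParams = fp)
* }}}
*/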
class FitParams {
private[module] var evalMetric: EvalMetric = new Accuracy()
private[module] var epochEndCallback: Option[EpochEndCallback] = None
private[module] var batchEndCallback: Option[BatchEndCallback] = None
private[module] var kvstore: String = "local"
private[module] var optimizer: Optimizer = new SGD()
private[module] var evalEndCallback: Option[BatchEndCallback] = None
private[module] var evalBatchEndCallback: Option[BatchEndCallback] = None
private[module] var initializer: Initializer = new Uniform(0.01f)
private[module] var argParams: Map[String, NDArray] = null
private[module] var auxParams: Map[String, NDArray] = null
private[module] var allowMissing: Boolean = false
private[module] var forceRebind: Boolean = false
private[module] var forceInit: Boolean = false
private[module] var beginEpoch: Int = 0
private[module] var validationMetric: Option[EvalMetric] = None
private[module] var monitor: Option[Monitor] = None
// The performance measure displayed during training.
def setEvalMetric(evalMetric: EvalMetric): FitParams = {
require(evalMetric != null)
this.evalMetric = evalMetric
this
}
// Each callback will be called with the current
// `epoch`, `symbol`, `arg_params` and `aux_params`.
def setEpochEndCallback(epochEndCallback: EpochEndCallback): FitParams = {
this.epochEndCallback = Option(epochEndCallback)
this
}
// Each callback will be called with a `BatchEndParam`.
def setBatchEndCallback(batchEndCallback: BatchEndCallback): FitParams = {
this.batchEndCallback = Option(batchEndCallback)
this
}
def setKVStore(kvStore: String): FitParams = {
require(kvStore != null)
this.kvstore = kvStore
this
}
def setOptimizer(optimizer: Optimizer): FitParams = {
require(optimizer != null)
this.optimizer = optimizer
this
}
// These will be called at the end of each full evaluation,
// with the metrics over the entire evaluation set.
def setEvalEndCallback(evalEndCallback: BatchEndCallback): FitParams = {
this.evalEndCallback = Option(evalEndCallback)
this
}
// These will be called at the end of each minibatch during evaluation.
def setEvalBatchEndCallback(evalBatchEndCallback: BatchEndCallback): FitParams = {
this.evalBatchEndCallback = Option(evalBatchEndCallback)
this
}
// Will be called to initialize the module parameters if not already initialized.
def setInitializer(initializer: Initializer): FitParams = {
require(initializer != null)
this.initializer = initializer
this
}
// Default `null`. If not `null`, should be existing parameters from a trained
// model or loaded from a checkpoint (previously saved model). In this case,
// the value here will be used to initialize the module parameters,
// unless they are already initialized by the user
// via a call to `initParams` or `fit`.
// `argParams` has higher priority than `initializer`.
def setArgParams(argParams: Map[String, NDArray]): FitParams = {
this.argParams = argParams
this
}
// Default `null`. Similar to `argParams`, except for auxiliary states.
def setAuxParams(auxParams: Map[String, NDArray]): FitParams = {
this.auxParams = auxParams
this
}
// Default `False`. Indicate whether we allow missing parameters
// when `argParams` and `auxParams` are not `null`.
// If this is `True`, then the missing parameters will be
// initialized via the `initializer`.
def setAllowMissing(allowMissing: Boolean): FitParams = {
this.allowMissing = allowMissing
this
}
// Default `False`. Whether to force rebinding the executors if already bound.
def setForceRebind(forceRebind: Boolean): FitParams = {
this.forceRebind = forceRebind
this
}
// Default `False`. Indicate whether we should force initialization even if the
// parameters are already initialized.
def setForceInit(forceInit: Boolean): FitParams = {
this.forceInit = forceInit
this
}
// Default `0`. Indicate the starting epoch. Usually, if we are resuming from a
// checkpoint saved at a previous training phase at epoch N,
// then we should specify this value as N+1.
def setBeginEpoch(beginEpoch: Int): FitParams = {
require(beginEpoch >= 0)
this.beginEpoch = beginEpoch
this
}
def setValidationMetric(metric: EvalMetric): FitParams = {
this.validationMetric = Option(metric)
this
}
def setMonitor(monitor: Monitor): FitParams = {
this.monitor = Option(monitor)
this
}
}