| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package ml.dmlc.mxnet |
| |
| import java.nio.ByteBuffer |
| |
| import org.slf4j.LoggerFactory |
| |
| import scala.collection.mutable |
| |
| /** |
| * Describe the model flow |
| */ |
| class Model |
| object Model { |
| private val logger = LoggerFactory.getLogger(classOf[Model]) |
| |
| /** |
| * Checkpoint the model data into file. |
| * @param prefix Prefix of model name. |
| * @param epoch The epoch number of the model. |
| * @param symbol The input symbol |
| * @param argParams Model parameter, dict of name to NDArray of net's weights. |
| * @param auxParams Model parameter, dict of name to NDArray of net's auxiliary states. |
| * @note |
| * - ``prefix-symbol.json`` will be saved for symbol. |
| * - ``prefix-epoch.params`` will be saved for parameters. |
| */ |
| def saveCheckpoint(prefix: String, epoch: Int, symbol: Symbol, |
| argParams: Map[String, NDArray], auxParams: Map[String, NDArray]): Unit = { |
| symbol.save(s"$prefix-symbol.json") |
| val saveDict = argParams.map { case (k, v) => s"arg:$k" -> v } ++ |
| auxParams.map { case (k, v) => s"aux:$k" -> v } |
| val paramName = "%s-%04d.params".format(prefix, epoch) |
| NDArray.save(paramName, saveDict) |
| logger.info(s"Saved checkpoint to $paramName") |
| } |
| |
| /** |
| * Load model checkpoint from file. |
| * |
| * @param prefix Prefix of model name. |
| * @param epoch Epoch number of model we would like to load. |
| * |
| * @return |
| * symbol : The symbol configuration of computation network. |
| * argParams : Model parameter, dict of name to NDArray of net's weights. |
| * auxParams : Model parameter, dict of name to NDArray of net's auxiliary states. |
| * @note |
| * - symbol will be loaded from ``prefix-symbol.json``. |
| * - parameters will be loaded from ``prefix-epoch.params``. |
| */ |
| def loadCheckpoint(prefix: String, epoch: Int): |
| (Symbol, Map[String, NDArray], Map[String, NDArray]) = { |
| val symbol = Symbol.load(s"$prefix-symbol.json") |
| val saveDict = NDArray.load("%s-%04d.params".format(prefix, epoch)) |
| val argParams = mutable.HashMap[String, NDArray]() |
| val auxParams = mutable.HashMap[String, NDArray]() |
| for ((k, v) <- saveDict._1 zip saveDict._2) { |
| val splitted = k.split(":", 2) |
| val tp = splitted(0) |
| val name = splitted(1) |
| if (tp == "arg") { |
| argParams(name) = v |
| } else if (tp == "aux") { |
| auxParams(name) = v |
| } |
| } |
| (symbol, argParams.toMap, auxParams.toMap) |
| } |
| |
| // a helper class for serializing model |
| class SerializedModel private[mxnet] ( |
| val symbol: String, |
| val argParams: Map[String, Array[Byte]], |
| val auxParams: Map[String, Array[Byte]]) extends Serializable |
| |
| private[mxnet] def serialize(symbol: Symbol, |
| argParams: Map[String, NDArray], |
| auxParams: Map[String, NDArray]): Array[Byte] = { |
| val serializedModel = new SerializedModel( |
| symbol.toJson, |
| argParams.map { case (k, v) => (k, v.serialize()) }, |
| auxParams.map { case (k, v) => (k, v.serialize()) } |
| ) |
| Serializer.getSerializer.serialize(serializedModel).array() |
| } |
| |
| private[mxnet] def deserialize(bytes: Array[Byte]): |
| (Symbol, Map[String, NDArray], Map[String, NDArray]) = { |
| val model = Serializer.getSerializer.deserialize[SerializedModel](ByteBuffer.wrap(bytes)) |
| val symbol = Symbol.loadJson(model.symbol) |
| val argParams = model.argParams.map { case (k, v) => |
| (k, NDArray.deserialize(v)) |
| } |
| val auxParams = model.auxParams.map { case (k, v) => |
| (k, NDArray.deserialize(v)) |
| } |
| (symbol, argParams, auxParams) |
| } |
| |
| /** |
| * Create kvstore |
| * This function select and create a proper kvstore given the kvstore type |
| * @param kvStore KVStore type |
| * @param numDevice The number of devices |
| * @param argParams Model parameter, dict of name to NDArray of net's weights. |
| * @return Option of created [[KVStore]] and whether or not update weight on it |
| */ |
| private[mxnet] def createKVStore(kvStore: String, |
| numDevice: Int, |
| argParams: Map[String, NDArray]): (Option[KVStore], Boolean) = { |
| if (numDevice == 1 && !kvStore.contains("dist")) { |
| // no need to use kv for single device and single machine |
| (None, false) |
| } else { |
| var kvType = kvStore |
| if (kvType == "local") { |
| // automatically select a proper local |
| val maxSize = argParams.values.map(_.shape.product).max |
| kvType = |
| if (maxSize < 1024 * 1024 * 16) { |
| "local_update_cpu" |
| } else { |
| "local_allreduce_cpu" |
| } |
| logger.info(s"Auto - select kvstore type = $kvType") |
| } |
| (Option(KVStore.create(kvType)), !kvType.contains("local_allreduce")) |
| } |
| } |
| |
| /** |
| * Create a kvStore (wrap it with Option, None if given kvStore == null) |
| * @param kvStore KVStore |
| * @return Option of created [[KVStore]] and whether or not update weight on it |
| */ |
| private[mxnet] def createKVStore(kvStore: KVStore): (Option[KVStore], Boolean) = { |
| (Option(kvStore), kvStore != null && !kvStore.`type`.contains("local_allreduce")) |
| } |
| |
| // Initialize kvstore |
| private[mxnet] def initializeKVStore(kvStore: KVStore, |
| paramArrays: IndexedSeq[Array[NDArray]], |
| argParams: Map[String, NDArray], |
| paramNames: IndexedSeq[String], |
| updateOnKVStore: Boolean): Unit = { |
| require(paramArrays.length == paramNames.length) |
| for (idx <- 0 until paramArrays.length) { |
| val paramOnDevs = paramArrays(idx) |
| val name = paramNames(idx) |
| kvStore.init(name, argParams(paramNames(idx))) |
| if (updateOnKVStore) { |
| kvStore.pull(name, paramOnDevs, -idx) |
| } |
| } |
| } |
| |
| // Perform update of param_arrays from grad_arrays on kvstore |
| private[mxnet] def updateParamsOnKVStore(paramArrays: IndexedSeq[Array[NDArray]], |
| gradArrays: IndexedSeq[Array[NDArray]], |
| kvStore: Option[KVStore], |
| paramNames: IndexedSeq[String]): Unit = { |
| (paramArrays zip gradArrays).zipWithIndex.foreach { case ((argList, gradList), index) => |
| if (gradList != null) { |
| val name = paramNames(index) |
| // push gradient, priority is negative index |
| kvStore.foreach(_.push(name, gradList, -index)) |
| // pull back the weights |
| kvStore.foreach(_.pull(name, argList, -index)) |
| } |
| } |
| } |
| |
| // Perform update of param_arrays from grad_arrays not on kvstore |
| private[mxnet] def updateParams(paramArrays: IndexedSeq[Array[NDArray]], |
| gradArrays: IndexedSeq[Array[NDArray]], |
| updater: MXKVStoreUpdater, |
| numDevice: Int, |
| paramNames: IndexedSeq[String], |
| kvStore: Option[KVStore] = None) { |
| (paramArrays zip gradArrays).zipWithIndex.foreach { case ((argList, gradList), index) => |
| if (gradList != null) { |
| kvStore.foreach(kv => { |
| val name = paramNames(index) |
| // push gradient, priority is negative index |
| kv.push(name, gradList, -index) |
| // pull back the sum gradients, to the same locations. |
| kv.pull(name, gradList, -index) |
| }) |
| (argList zip gradList).zipWithIndex.foreach { case ((w: NDArray, g: NDArray), k: Int) => |
| // faked an index here, to make optimizer create diff |
| // state for the same index but on diff devs, |
| // (copy from python package) TODO(mli) use a better solution latter |
| updater.update(index * numDevice + k, g, w) |
| } |
| } |
| } |
| } |
| |
| /** |
| * Internal training function on multiple devices. |
| * This function will also work for single device as well. |
| * @param symbol The network configuration |
| * @param ctx The training devices. |
| * @param argNames Name of all arguments of the network. |
| * @param paramNames Name of all trainable parameters of the network. |
| * @param auxNames Name of all auxiliary states of the network. |
| * @param argParams Model parameter, dict of name to NDArray of net's weights. |
| * @param auxParams Model parameter, dict of name to NDArray of net's auxiliary states. |
| * @param beginEpoch The begining training epoch. |
| * @param endEpoch The end training epoch. |
| * @param epochSize Number of batches in a epoch. |
| * In default, it is set to ceil(num_train_examples / batch_size) |
| * @param optimizer The optimization algorithm |
| * @param kvStore The KVStore |
| * @param updateOnKVStore whether or not perform weight updating on kvstore |
| * @param trainData Training data iterator. |
| * @param evalData Validation data iterator. |
| * @param evalMetric A evaluation function. |
| * @param epochEndCallback A callback that is invoked at end of each epoch. |
| * This can be used to checkpoint model each epoch. |
| * @param batchEndCallback A callback that is invoked at end of each batch. |
| * This can be used to measure speed, |
| * get result from evaluation metric. etc. |
| * @param workLoadList The list of work load for different devices, in the same order as ctx |
| * @param monitor Monitor outputs, weights, and gradients for debugging |
| * @note This function will inplace update the NDArrays in argParams and auxStates. |
| */ |
| // scalastyle:off parameterNum |
| private[mxnet] def trainMultiDevice(symbol: Symbol, ctx: Array[Context], |
| argNames: IndexedSeq[String], paramNames: IndexedSeq[String], |
| auxNames: IndexedSeq[String], argParams: Map[String, NDArray], |
| auxParams: Map[String, NDArray], |
| beginEpoch: Int, endEpoch: Int, epochSize: Int, |
| optimizer: Optimizer, |
| kvStore: Option[KVStore], updateOnKVStore: Boolean, |
| trainData: DataIter, |
| evalData: Option[DataIter] = None, |
| evalMetric: EvalMetric, |
| epochEndCallback: Option[EpochEndCallback] = None, |
| batchEndCallback: Option[BatchEndCallback] = None, |
| workLoadList: Seq[Float] = Nil, |
| monitor: Option[Monitor] = None, |
| symGen: SymbolGenerator = null): Unit = { |
| val executorManager = new DataParallelExecutorManager( |
| symbol = symbol, |
| symGen = symGen, |
| ctx = ctx, |
| trainData = trainData, |
| paramNames = paramNames, |
| argNames = argNames, |
| auxNames = auxNames, |
| workLoadList = workLoadList) |
| |
| monitor.foreach(executorManager.installMonitor) |
| executorManager.setParams(argParams, auxParams) |
| |
| // updater for updateOnKVStore = false |
| val updaterLocal = Optimizer.getUpdater(optimizer) |
| |
| kvStore.foreach(initializeKVStore(_, executorManager.paramArrays, |
| argParams, executorManager.paramNames, updateOnKVStore)) |
| if (updateOnKVStore) { |
| kvStore.foreach(_.setOptimizer(optimizer)) |
| } |
| |
| // Now start training |
| for (epoch <- beginEpoch until endEpoch) { |
| // Training phase |
| val tic = System.currentTimeMillis |
| evalMetric.reset() |
| var nBatch = 0 |
| var epochDone = false |
| // Iterate over training data. |
| trainData.reset() |
| while (!epochDone) { |
| var doReset = true |
| while (doReset && trainData.hasNext) { |
| val dataBatch = trainData.next() |
| executorManager.loadDataBatch(dataBatch) |
| monitor.foreach(_.tic()) |
| executorManager.forward(isTrain = true) |
| executorManager.backward() |
| if (updateOnKVStore) { |
| updateParamsOnKVStore(executorManager.paramArrays, |
| executorManager.gradArrays, |
| kvStore, executorManager.paramNames) |
| } else { |
| updateParams(executorManager.paramArrays, |
| executorManager.gradArrays, |
| updaterLocal, ctx.length, |
| executorManager.paramNames, |
| kvStore) |
| } |
| monitor.foreach(_.tocPrint()) |
| // evaluate at end, so out_cpu_array can lazy copy |
| executorManager.updateMetric(evalMetric, dataBatch.label) |
| |
| nBatch += 1 |
| batchEndCallback.foreach(_.invoke(epoch, nBatch, evalMetric)) |
| |
| // this epoch is done possibly earlier |
| if (epochSize != -1 && nBatch >= epochSize) { |
| doReset = false |
| } |
| } |
| if (doReset) { |
| trainData.reset() |
| } |
| |
| // this epoch is done |
| epochDone = (epochSize == -1 || nBatch >= epochSize) |
| } |
| |
| val (name, value) = evalMetric.get |
| name.zip(value).foreach { case (n, v) => |
| logger.info(s"Epoch[$epoch] Train-$n=$v") |
| } |
| val toc = System.currentTimeMillis |
| logger.info(s"Epoch[$epoch] Time cost=${toc - tic}") |
| |
| evalData.foreach { evalDataIter => |
| evalMetric.reset() |
| evalDataIter.reset() |
| // TODO: make DataIter implement Iterator |
| while (evalDataIter.hasNext) { |
| val evalBatch = evalDataIter.next() |
| executorManager.loadDataBatch(evalBatch) |
| executorManager.forward(isTrain = false) |
| executorManager.updateMetric(evalMetric, evalBatch.label) |
| } |
| |
| val (name, value) = evalMetric.get |
| name.zip(value).foreach { case (n, v) => |
| logger.info(s"Epoch[$epoch] Train-$n=$v") |
| } |
| } |
| |
| if (epochEndCallback.isDefined || epoch + 1 == endEpoch) { |
| executorManager.copyTo(argParams, auxParams) |
| } |
| epochEndCallback.foreach(_.invoke(epoch, symbol, argParams, auxParams)) |
| } |
| |
| updaterLocal.dispose() |
| executorManager.dispose() |
| } |
| // scalastyle:on parameterNum |
| } |
| |
| trait EpochEndCallback { |
| def invoke(epoch: Int, symbol: Symbol, |
| argParams: Map[String, NDArray], |
| auxStates: Map[String, NDArray]): Unit |
| } |
| |
| trait BatchEndCallback { |
| def invoke(epoch: Int, nBatch: Int, evalMetric: EvalMetric) |
| } |