/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.mxnetexamples.imclassification

import ml.dmlc.mxnet._
import org.kohsuke.args4j.{CmdLineParser, Option}
import org.slf4j.LoggerFactory

import scala.collection.mutable
import scala.collection.JavaConverters._
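
/**
 * MNIST image-classification example for the MXNet Scala API.
 *
 * Trains either a multi-layer perceptron ("mlp") or a LeNet-style CNN ("lenet"),
 * locally or distributed through a parameter-server KVStore
 * (scheduler / server / worker roles).
 *
 * Illustrative invocation (classpath and native-library setup depend on your build):
 *   java -cp <mxnet-examples-jar> ml.dmlc.mxnetexamples.imclassification.TrainMnist \
 *     --data-dir mnist/ --network lenet --gpus 0 --num-epochs 10
 */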
object TrainMnist {
  private val logger = LoggerFactory.getLogger(classOf[TrainMnist])

  // multi-layer perceptron
  def getMlp: Symbol = {
    val data = Symbol.Variable("data")
    val fc1 = Symbol.FullyConnected(name = "fc1")()(Map("data" -> data, "num_hidden" -> 128))
    val act1 = Symbol.Activation(name = "relu1")()(Map("data" -> fc1, "act_type" -> "relu"))
    val fc2 = Symbol.FullyConnected(name = "fc2")()(Map("data" -> act1, "num_hidden" -> 64))
    val act2 = Symbol.Activation(name = "relu2")()(Map("data" -> fc2, "act_type" -> "relu"))
    val fc3 = Symbol.FullyConnected(name = "fc3")()(Map("data" -> act2, "num_hidden" -> 10))
    val mlp = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc3))
    mlp
  }
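
  // LeNet-style convolutional network: two conv / tanh / max-pool stages
  // followed by two fully connected layers.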
  // LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick
  // Haffner. "Gradient-based learning applied to document recognition."
  // Proceedings of the IEEE (1998)
  def getLenet: Symbol = {
    val data = Symbol.Variable("data")
    // first conv
    val conv1 = Symbol.Convolution()()(
      Map("data" -> data, "kernel" -> "(5, 5)", "num_filter" -> 20))
    val tanh1 = Symbol.Activation()()(Map("data" -> conv1, "act_type" -> "tanh"))
    val pool1 = Symbol.Pooling()()(Map("data" -> tanh1, "pool_type" -> "max",
      "kernel" -> "(2, 2)", "stride" -> "(2, 2)"))
    // second conv
    val conv2 = Symbol.Convolution()()(
      Map("data" -> pool1, "kernel" -> "(5, 5)", "num_filter" -> 50))
    val tanh2 = Symbol.Activation()()(Map("data" -> conv2, "act_type" -> "tanh"))
    val pool2 = Symbol.Pooling()()(Map("data" -> tanh2, "pool_type" -> "max",
      "kernel" -> "(2, 2)", "stride" -> "(2, 2)"))
    // first fully connected layer
    val flatten = Symbol.Flatten()()(Map("data" -> pool2))
    val fc1 = Symbol.FullyConnected()()(Map("data" -> flatten, "num_hidden" -> 500))
    val tanh3 = Symbol.Activation()()(Map("data" -> fc1, "act_type" -> "tanh"))
    // second fully connected layer
    val fc2 = Symbol.FullyConnected()()(Map("data" -> tanh3, "num_hidden" -> 10))
    // loss
    val lenet = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc2))
    lenet
  }
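
  // Build the training and validation MNIST iterators. When running distributed,
  // each worker reads its own shard of the data via num_parts / part_index.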
  def getIterator(dataShape: Shape)
                 (dataDir: String, batchSize: Int, kv: KVStore): (DataIter, DataIter) = {
    val flat = if (dataShape.size == 3) "False" else "True"

    val train = IO.MNISTIter(Map(
      "image" -> (dataDir + "train-images-idx3-ubyte"),
      "label" -> (dataDir + "train-labels-idx1-ubyte"),
      "label_name" -> "softmax_label",
      "input_shape" -> dataShape.toString,
      "batch_size" -> batchSize.toString,
      "shuffle" -> "True",
      "flat" -> flat,
      "num_parts" -> kv.numWorkers.toString,
      "part_index" -> kv.rank.toString))

    val eval = IO.MNISTIter(Map(
      "image" -> (dataDir + "t10k-images-idx3-ubyte"),
      "label" -> (dataDir + "t10k-labels-idx1-ubyte"),
      "label_name" -> "softmax_label",
      "input_shape" -> dataShape.toString,
      "batch_size" -> batchSize.toString,
      "flat" -> flat,
      "num_parts" -> kv.numWorkers.toString,
      "part_index" -> kv.rank.toString))

    (train, eval)
  }
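
  // Entry point: parse command-line options, build the network and data shape,
  // pick devices, optionally initialize the parameter-server environment, and
  // either block as a scheduler/server or run training as a worker.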
  def main(args: Array[String]): Unit = {
    val inst = new TrainMnist
    val parser: CmdLineParser = new CmdLineParser(inst)
    try {
      parser.parseArgument(args.toList.asJava)

      val dataPath = if (inst.dataDir == null) System.getenv("MXNET_DATA_DIR")
                     else inst.dataDir

      val (dataShape, net) =
        if (inst.network == "mlp") (Shape(784), getMlp)
        else (Shape(1, 28, 28), getLenet)

      val devs =
        if (inst.gpus != null) inst.gpus.split(',').map(id => Context.gpu(id.trim.toInt))
        else if (inst.cpus != null) inst.cpus.split(',').map(id => Context.cpu(id.trim.toInt))
        else Array(Context.cpu(0))
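
      // Environment variables consumed by the distributed KVStore; KVStoreServer.init
      // is only called when a scheduler host is configured.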
      val envs: mutable.Map[String, String] = mutable.HashMap.empty[String, String]
      envs.put("DMLC_ROLE", inst.role)
      if (inst.schedulerHost != null) {
        require(inst.schedulerPort > 0, "scheduler port must be specified")
        envs.put("DMLC_PS_ROOT_URI", inst.schedulerHost)
        envs.put("DMLC_PS_ROOT_PORT", inst.schedulerPort.toString)
        require(inst.numWorker > 0, "number of workers must be > 0")
        envs.put("DMLC_NUM_WORKER", inst.numWorker.toString)
        require(inst.numServer > 0, "number of servers must be > 0")
        envs.put("DMLC_NUM_SERVER", inst.numServer.toString)
        logger.info("Init PS environments")
        KVStoreServer.init(envs.toMap)
      }
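
      // Scheduler and server roles block inside KVStoreServer.start();
      // only the worker role builds iterators and runs the training loop.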
      if (inst.role != "worker") {
        logger.info("Start KVStoreServer for scheduler & servers")
        KVStoreServer.start()
      } else {
        ModelTrain.fit(dataDir = dataPath,
          batchSize = inst.batchSize, numExamples = inst.numExamples, devs = devs,
          network = net, dataLoader = getIterator(dataShape),
          kvStore = inst.kvStore, numEpochs = inst.numEpochs,
          modelPrefix = inst.modelPrefix, loadEpoch = inst.loadEpoch,
          lr = inst.lr, lrFactor = inst.lrFactor, lrFactorEpoch = inst.lrFactorEpoch,
          monitorSize = inst.monitor)
        logger.info("Finished fit")
      }
    } catch {
      case ex: Exception =>
        logger.error(ex.getMessage, ex)
        parser.printUsage(System.err)
        sys.exit(1)
    }
  }
}
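
// Command-line options parsed by args4j into this holder class;
// an instance is populated by CmdLineParser in main above.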
class TrainMnist {
  @Option(name = "--network", usage = "the network to use: ['mlp', 'lenet']")
  private val network: String = "mlp"
  @Option(name = "--data-dir", usage = "the input data directory")
  private val dataDir: String = "mnist/"
  @Option(name = "--gpus", usage = "the GPUs to use, e.g. '0,1,2,3'")
  private val gpus: String = null
  @Option(name = "--cpus", usage = "the CPUs to use, e.g. '0,1,2,3'")
  private val cpus: String = null
  @Option(name = "--num-examples", usage = "the number of training examples")
  private val numExamples: Int = 60000
  @Option(name = "--batch-size", usage = "the batch size")
  private val batchSize: Int = 128
  @Option(name = "--lr", usage = "the initial learning rate")
  private val lr: Float = 0.1f
  @Option(name = "--model-prefix", usage = "the prefix of the model to load/save")
  private val modelPrefix: String = null
  @Option(name = "--num-epochs", usage = "the number of training epochs")
  private val numEpochs = 10
  @Option(name = "--load-epoch", usage = "load the model saved at this epoch under --model-prefix")
  private val loadEpoch: Int = -1
  @Option(name = "--kv-store", usage = "the kvstore type")
  private val kvStore = "local"
  @Option(name = "--lr-factor",
    usage = "multiply the lr by this factor every lr-factor-epoch epochs")
  private val lrFactor: Float = 1f
  @Option(name = "--lr-factor-epoch",
    usage = "the number of epochs between lr decays; may be fractional, e.g. 0.5")
  private val lrFactorEpoch: Float = 1f
  @Option(name = "--monitor", usage = "monitor the training process every N batches")
  private val monitor: Int = -1
  @Option(name = "--role", usage = "scheduler/server/worker")
  private val role: String = "worker"
  @Option(name = "--scheduler-host", usage = "scheduler hostname / IP address")
  private val schedulerHost: String = null
  @Option(name = "--scheduler-port", usage = "scheduler port")
  private val schedulerPort: Int = 0
  @Option(name = "--num-worker", usage = "number of workers")
  private val numWorker: Int = 1
  @Option(name = "--num-server", usage = "number of servers")
  private val numServer: Int = 1
}