blob: a49044bad268fa40f7a8bee2493c4871552edbfa [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.mxnetexamples.rnn
import ml.dmlc.mxnet.Callback.Speedometer
import ml.dmlc.mxnet._
import BucketIo.BucketSentenceIter
import ml.dmlc.mxnet.optimizer.SGD
import org.kohsuke.args4j.{CmdLineParser, Option}
import org.slf4j.{Logger, LoggerFactory}
import scala.collection.JavaConverters._
/**
* Bucketing LSTM examples
* @author Yizhi Liu
*/
class LstmBucketing {
@Option(name = "--data-train", usage = "training set")
private val dataTrain: String = "example/rnn/ptb.train.txt"
@Option(name = "--data-val", usage = "validation set")
private val dataVal: String = "example/rnn/ptb.valid.txt"
@Option(name = "--num-epoch", usage = "the number of training epoch")
private val numEpoch: Int = 5
@Option(name = "--gpus", usage = "the gpus will be used, e.g. '0,1,2,3'")
private val gpus: String = null
@Option(name = "--cpus", usage = "the cpus will be used, e.g. '0,1,2,3'")
private val cpus: String = null
@Option(name = "--save-model-path", usage = "the model saving path")
private val saveModelPath: String = "model/lstm"
}
object LstmBucketing {
private val logger: Logger = LoggerFactory.getLogger(classOf[LstmBucketing])
def perplexity(label: NDArray, pred: NDArray): Float = {
val labelArr = label.T.toArray.map(_.toInt)
var loss = .0
(0 until pred.shape(0)).foreach(i =>
loss -= Math.log(Math.max(1e-10f, pred.slice(i).toArray(labelArr(i))))
)
Math.exp(loss / labelArr.length).toFloat
}
def main(args: Array[String]): Unit = {
val inst = new LstmBucketing
val parser: CmdLineParser = new CmdLineParser(inst)
try {
parser.parseArgument(args.toList.asJava)
val contexts =
if (inst.gpus != null) inst.gpus.split(',').map(id => Context.gpu(id.trim.toInt))
else if (inst.cpus != null) inst.cpus.split(',').map(id => Context.cpu(id.trim.toInt))
else Array(Context.cpu(0))
val batchSize = 32
val buckets = Array(10, 20, 30, 40, 50, 60)
val numHidden = 200
val numEmbed = 200
val numLstmLayer = 2
val learningRate = 0.01f
val momentum = 0.0f
logger.info("Building vocab ...")
val vocab = BucketIo.defaultBuildVocab(inst.dataTrain)
class BucketSymGen extends SymbolGenerator {
override def generate(key: AnyRef): Symbol = {
val seqLen = key.asInstanceOf[Int]
Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size,
numHidden = numHidden, numEmbed = numEmbed, numLabel = vocab.size)
}
}
val initC = (0 until numLstmLayer).map(l =>
(s"l${l}_init_c", (batchSize, numHidden))
)
val initH = (0 until numLstmLayer).map(l =>
(s"l${l}_init_h", (batchSize, numHidden))
)
val initStates = initC ++ initH
val dataTrain = new BucketSentenceIter(inst.dataTrain, vocab,
buckets, batchSize, initStates)
val dataVal = new BucketSentenceIter(inst.dataVal, vocab,
buckets, batchSize, initStates)
logger.info("Start training ...")
val model = FeedForward.newBuilder(new BucketSymGen())
.setContext(contexts)
.setNumEpoch(inst.numEpoch)
.setOptimizer(new SGD(learningRate = learningRate, momentum = momentum, wd = 0.00001f))
.setInitializer(new Xavier(factorType = "in", magnitude = 2.34f))
.setTrainData(dataTrain)
.setEvalData(dataVal)
.setEvalMetric(new CustomMetric(perplexity, name = "perplexity"))
.setBatchEndCallback(new Speedometer(batchSize, 50))
.build()
model.save(inst.saveModelPath)
} catch {
case ex: Exception =>
logger.error(ex.getMessage, ex)
parser.printUsage(System.err)
sys.exit(1)
}
}
}