/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.mxnetexamples.rnn

import ml.dmlc.mxnet._
import org.kohsuke.args4j.{CmdLineParser, Option}
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._
/**
 * Follows the demo to test the char RNN:
 * https://github.com/dmlc/mxnet/blob/master/example/rnn/char-rnn.ipynb
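 *
 * Command-line options (the values shown are the defaults; the model prefix
 * assumes a checkpoint previously produced by TrainCharRnn):
 * {{{
 * --data-path ./data/obama.txt --model-prefix ./model/obama --starter-sentence "The joke"
 * }}}
 *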
 * @author Depeng Liang
 */
object TestCharRnn {

  private val logger = LoggerFactory.getLogger(classOf[TestCharRnn])

  def main(args: Array[String]): Unit = {
    val stcr = new TestCharRnn
    val parser: CmdLineParser = new CmdLineParser(stcr)
    try {
      parser.parseArgument(args.toList.asJava)
      assert(stcr.dataPath != null && stcr.modelPrefix != null && stcr.starterSentence != null)
      // the batch size used during training (not needed for inference below)
      val batchSize = 32
      // We can support variable-length inputs.
      // For this problem each input sentence is cut to a length of 129,
      // so we only need a single fixed-length bucket.
      val buckets = List(129)
      // number of hidden units in each LSTM cell
      val numHidden = 512
      // embedding dimension, i.e. each char is mapped to a 256-dim vector
      val numEmbed = 256
      // number of LSTM layers
      val numLstmLayer = 3
      // build the char vocabulary from the input data
      val vocab = Utils.buildVocab(stcr.dataPath)
      // load the parameters saved at epoch 75 during training; the symbol
      // and aux params are discarded because the inference graph is rebuilt
      // below by LSTMInferenceModel
      val (_, argParams, _) = Model.loadCheckpoint(stcr.modelPrefix, 75)
      // build an inference model; vocab indices start at 1 (0 appears to be
      // reserved for padding), hence the vocab.size + 1 dimensions
      val model = new RnnModel.LSTMInferenceModel(numLstmLayer, vocab.size + 1,
        numHidden = numHidden, numEmbed = numEmbed,
        numLabel = vocab.size + 1, argParams = argParams, dropout = 0.2f)
      // generate a sequence of 1200 chars
      val seqLength = 1200
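      // a one-element NDArray holding the id of the current input char,
      // (re)filled by Utils.makeInput on every step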
      val inputNdarray = NDArray.zeros(1)
      val revertVocab = Utils.makeRevertVocab(vocab)
      // Feel free to change the starter sentence
      var output = stcr.starterSentence
      val randomSample = true
      var newSentence = true
      val ignoreLength = output.length()
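      // Generation runs in two phases: first the starter sentence is fed
      // through the model char by char to warm up the LSTM state (those
      // chars are not appended to the output again), then each sampled
      // char is fed back in as the next input.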
      for (i <- 0 until seqLength) {
        if (i <= ignoreLength - 1) Utils.makeInput(output(i), vocab, inputNdarray)
        else Utils.makeInput(output.last, vocab, inputNdarray)
        val prob = model.forward(inputNdarray, newSentence)
        newSentence = false
        val nextChar = Utils.makeOutput(prob, revertVocab, randomSample)
        // an empty output marks the end of a sentence,
        // so the LSTM state is reset on the next forward pass
        if (nextChar == "") newSentence = true
        if (i >= ignoreLength) output = output ++ nextChar
      }
      // Let's see what the model learned from the chars in Obama's speeches.
      logger.info(output)
    } catch {
      case ex: Exception => {
        logger.error(ex.getMessage, ex)
        parser.printUsage(System.err)
        sys.exit(1)
      }
    }
  }
}

class TestCharRnn {
  @Option(name = "--data-path", usage = "the input train data file")
  private val dataPath: String = "./data/obama.txt"
  @Option(name = "--model-prefix", usage = "the model prefix")
  private val modelPrefix: String = "./model/obama"
  @Option(name = "--starter-sentence", usage = "the starter sentence")
  private val starterSentence: String = "The joke"
}