# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Sherlock Holmes language model using an LSTM. Replace mx.lstm with
# mx.gru or mx.rnn below to train a GRU or vanilla RNN model instead.
# The data files can be found at:
# https://github.com/dmlc/web-data/tree/master/mxnet/sherlockholmes
require(hash)
require(mxnet)
require(stringr)
# Read a text file, tokenize it on spaces, and map each word to an
# integer id. An existing dictionary can be passed in so the validation
# set reuses (and extends) the training vocabulary.
load.data <- function(path, dic = NULL) {
  fi <- file(path, "r")
  # Join lines with an explicit " <eos> " token so line breaks appear
  # as a separate word instead of fusing the words around them.
  content <- paste(readLines(fi), collapse = " <eos> ")
  close(fi)
  content <- str_split(content, ' ')[[1]]
  cat(paste0("Loading ", path, ", size of data = ", length(content), "\n"))
  X <- array(0, dim = c(length(content)))
  if (is.null(dic))
    dic <- hash()
  # Continue numbering after any ids already in the dictionary so new
  # words from a second file do not collide with existing ids.
  idx <- length(dic) + 1
  for (i in 1:length(content)) {
    word <- content[i]
    if (str_length(word) > 0) {
      if (!has.key(word, dic)) {
        dic[[word]] <- idx
        idx <- idx + 1
      }
      X[i] <- dic[[word]]
    }
  }
  cat(paste0("Unique tokens: ", length(dic), "\n"))
  return(list(X = X, dic = dic))
}
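# Illustration (hypothetical two-line file, not part of the dataset):
# the lines "the dog" and "the cat" join into "the dog <eos> the cat",
# so load.data returns X = c(1, 2, 3, 1, 4) with dic mapping
# the = 1, dog = 2, <eos> = 3, cat = 4.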
# Truncate X to a whole number of sequences and reshape it into a
# seq.len x num.seq matrix, one sequence per column.
replicate.data <- function(X, seq.len) {
  num.seq <- as.integer(length(X) / seq.len)
  X <- X[1:(num.seq * seq.len)]
  dim(X) <- c(seq.len, num.seq)
  return(X)
}
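# For example, replicate.data(1:75, 35) drops the last 5 ids and
# returns a 35 x 2 matrix whose columns hold ids 1:35 and 36:70.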
# Drop trailing sequences so the number of columns is a multiple of
# batch.size.
drop.tail <- function(X, batch.size) {
  shape <- dim(X)
  nstep <- as.integer(shape[2] / batch.size)
  return(X[, 1:(nstep * batch.size)])
}
# Build next-word labels: label[k] is the token that follows X[k] in
# the flattened (column-major) stream, wrapping around at the end.
get.label <- function(X) {
  label <- array(0, dim = dim(X))
  d <- dim(X)[1]
  w <- dim(X)[2]
  for (i in 0:(w - 1)) {
    for (j in 1:d) {
      label[i * d + j] <- X[(i * d + j) %% (w * d) + 1]
    }
  }
  return(label)
}
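# For example, with X <- matrix(1:6, 2, 3) the labels are
# matrix(c(2, 3, 4, 5, 6, 1), 2, 3): every entry is shifted one step
# ahead, and the final label wraps back around to X[1].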
batch.size <- 20
seq.len <- 35
num.hidden <- 200
num.embed <- 200
num.lstm.layer <- 2
num.round <- 15
learning.rate <- 0.1
wd <- 0.00001
update.period <- 1
train <- load.data("./data/sherlockholmes.train.txt")
X.train <- train$X
dic <- train$dic
# Pass the training dictionary so validation ids share the same vocabulary.
val <- load.data("./data/sherlockholmes.valid.txt", dic)
X.val <- val$X
dic <- val$dic
X.train.data <- replicate.data(X.train, seq.len)
X.val.data <- replicate.data(X.val, seq.len)
vocab <- length(dic)
cat(paste0("Vocab=", vocab, "\n"))
X.train.data <- drop.tail(X.train.data, batch.size)
X.val.data <- drop.tail(X.val.data, batch.size)
X.train.label <- get.label(X.train.data)
X.val.label <- get.label(X.val.data)
X.train <- list(data=X.train.data, label=X.train.label)
X.val <- list(data=X.val.data, label=X.val.label)
model <- mx.lstm(X.train, X.val,
                 ctx = mx.gpu(0),
                 num.round = num.round,
                 update.period = update.period,
                 num.lstm.layer = num.lstm.layer,
                 seq.len = seq.len,
                 num.hidden = num.hidden,
                 num.embed = num.embed,
                 num.label = vocab,
                 batch.size = batch.size,
                 input.size = vocab,
                 initializer = mx.init.uniform(0.01),
                 learning.rate = learning.rate,
                 wd = wd)
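# A minimal sampling sketch (not part of the original script), assuming
# the mx.lstm.inference/mx.lstm.forward helpers shipped with the same
# MXNet R package; argument names follow the package's char-rnn example
# and the seed word "holmes" is hypothetical. Depending on the MXNet
# version, the input id may need to be 0-based (word.id - 1). Left
# commented out so the training script runs unchanged.
# infer.model <- mx.lstm.inference(num.lstm.layer = num.lstm.layer,
#                                  input.size = vocab,
#                                  num.hidden = num.hidden,
#                                  num.embed = num.embed,
#                                  num.label = vocab,
#                                  arg.params = model$arg.params,
#                                  ctx = mx.cpu())
# word.id <- dic[["holmes"]]             # hypothetical seed word
# ret <- mx.lstm.forward(infer.model, word.id, TRUE)
# infer.model <- ret$model
# prob <- ret$prob                       # distribution over the vocabulary
# next.id <- which.max(as.array(prob))   # greedy choice of the next token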