blob: 2087195fc0e0981881954a08e58f91e04d38cfaa [file] [log] [blame]
# Char RNN 的例子
这个例子旨在展示如何使用 `lstm` 模型来构建一个字符级语言模型,并从中生成文本。我们使用一个很小的莎士比亚文本数据集作为演示。
数据源可以在这里找到: <https://github.com/dmlc/web-data/tree/master/mxnet/tinyshakespeare>.
## 前言
本教程由Rmd编辑完成。
- 你可以直接访问主站版本的教程:[MXNet R Document](http://mxnet.io/tutorials/r/charRnnModel.html)
- 你也可以从[这里](https://github.com/dmlc/mxnet/blob/master/R-package/vignettes/CharRnnModel.Rmd) 下载到Rmarkdown源文件
## 加载数据
首先,加载数据并做预处理。
```{r}
require(mxnet)
```
设置基本的网络参数。
```{r}
# Data layout: `seq.len` is handed to make.data() and mx.lstm();
# `batch.size` is handed to drop.tail() and mx.lstm().
batch.size <- 32
seq.len <- 32

# Network dimensions, forwarded unchanged to mx.lstm() / mx.lstm.inference().
num.hidden <- 256
num.embed <- 256
num.lstm.layer <- 2

# Optimisation settings, forwarded unchanged to mx.lstm().
num.round <- 3
learning.rate <- 0.1
wd <- 0.00001
clip_gradient <- 1
update.period <- 1
```
下载数据。
```{r}
# Download the tiny-Shakespeare `input.txt` into `data_dir` unless the file
# is already present.
#
# Args:
#   data_dir: directory that receives `input.txt`; created if it is missing.
#             A trailing slash is optional.
#
# Returns:
#   The status code of download.file() when a download happened, otherwise
#   NULL (both invisibly).
download.data <- function(data_dir) {
  dir.create(data_dir, showWarnings = FALSE)
  # Strip trailing slashes and let file.path() insert the separator, so that
  # "./data" and "./data/" both resolve to "./data/input.txt".
  target <- file.path(sub("/+$", "", data_dir), "input.txt")
  if (!file.exists(target)) {
    # Keep using wget when it is installed; otherwise fall back to R's
    # default download method instead of failing outright.
    method <- if (nzchar(Sys.which("wget"))) "wget" else "auto"
    download.file(url='https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tinyshakespeare/input.txt',
                  destfile=target, method=method)
  }
}
```
从文本生成字典
```{r}
# Build a character -> index dictionary from a text.
#
# Characters are numbered 1, 2, ... in order of first appearance. At most
# `max.vocab - 1` distinct characters are kept; when that limit is reached
# (or exceeded) an extra "UNKNOWN" entry takes the next index so that the
# dictionary never holds more than `max.vocab` entries.
#
# Args:
#   text:      character vector; only its first element is used.
#   max.vocab: maximum dictionary size, including the "UNKNOWN" entry.
#
# Returns:
#   A named list mapping each kept character (and possibly "UNKNOWN") to a
#   numeric index. Also prints the dictionary size.
make.dict <- function(text, max.vocab=10000) {
  # unique() keeps first-appearance order and avoids re-scanning the names
  # of the growing list for every character of the corpus.
  chars <- unique(strsplit(text, '')[[1]])
  keep <- max(max.vocab - 1, 0)
  needs.unknown <- length(chars) >= keep
  if (needs.unknown) {
    chars <- chars[seq_len(keep)]
  }
  # Indices are stored as doubles, as the incrementing counter did before.
  dic <- as.list(as.numeric(seq_along(chars)))
  names(dic) <- chars
  if (needs.unknown) {
    dic[["UNKNOWN"]] <- length(chars) + 1
  }
  cat(paste0("Total unique char: ", length(dic), "\n"))
  return (dic)
}
```
将文本转化成数据特征。
```{r}
# Read a text file and encode it as a matrix of zero-based character ids.
#
# Args:
#   file.path: path of the text file to read.
#   seq.len:   number of characters per sequence (rows of the result).
#   max.vocab: passed to make.dict() when `dic` is NULL.
#   dic:       optional named list mapping characters to 1-based indices;
#              built from the text when NULL. Assumes one scalar index per
#              entry.
#
# Returns:
#   A list with
#     data:         numeric array of dim c(seq.len, num.seq), filled column
#                   by column with (dictionary index - 1); trailing
#                   characters that do not fill a whole sequence are dropped.
#     dic:          the dictionary used.
#     lookup.table: list such that lookup.table[[index]] is the character.
make.data <- function(file.path, seq.len=32, max.vocab=10000, dic=NULL) {
  fi <- file(file.path, "r")
  # Release the connection even when readLines() fails.
  on.exit(close(fi), add = TRUE)
  text <- paste(readLines(fi), collapse="\n")

  if (is.null(dic)) {
    dic <- make.dict(text, max.vocab)
  }

  # Inverse mapping: index -> character.
  lookup.table <- list()
  for (ch in names(dic)) {
    lookup.table[[dic[[ch]]]] <- ch
  }

  char.lst <- strsplit(text, '')[[1]]
  num.seq <- as.integer(length(char.lst) / seq.len)
  # seq_len() yields an empty selection when the text is shorter than one
  # sequence, where 1:0 would wrongly pick the first character.
  char.lst <- char.lst[seq_len(num.seq * seq.len)]

  # Vectorised lookup: position of every character among the dictionary
  # names, then the index stored at that position.
  dic.codes <- as.numeric(unlist(dic, use.names = FALSE))
  pos <- match(char.lst, names(dic))
  codes <- dic.codes[pos]
  not.found <- is.na(pos)
  if (any(not.found)) {
    unknown <- dic[["UNKNOWN"]]
    if (is.null(unknown)) {
      stop("text contains characters missing from 'dic' and 'dic' has no \"UNKNOWN\" entry",
           call. = FALSE)
    }
    codes[not.found] <- unknown
  }

  # Ids are zero-based; column-major fill puts sequence i in column i.
  data <- array(codes - 1, dim=c(seq.len, num.seq))
  return (list(data=data, dic=dic, lookup.table=lookup.table))
}
```
截去尾部不足一个 batch 的数据
```{r}
# Drop trailing columns so that the column count is a multiple of batch.size.
#
# Args:
#   X:          matrix with one sequence per column.
#   batch.size: number of columns per batch.
#
# Returns:
#   The leading floor(ncol(X) / batch.size) * batch.size columns of X,
#   always as a matrix (possibly with zero columns).
drop.tail <- function(X, batch.size) {
  shape <- dim(X)
  nstep <- as.integer(shape[2] / batch.size)
  # seq_len() gives an empty selection when fewer than batch.size columns
  # exist (1:0 would select column 1); drop = FALSE keeps the result a matrix
  # even when a single row or column remains.
  return (X[, seq_len(nstep * batch.size), drop = FALSE])
}
```
获取 X 的标签
```{r}
# Build next-character labels for an id matrix.
#
# Reading X in column-major order, the label of every element is the element
# that follows it; the label of the very last element wraps around to the
# first element of X.
#
# Args:
#   X: array of ids, one sequence per column.
#
# Returns:
#   A numeric array with the same dim as X.
get.label <- function(X) {
  n <- length(X)
  label <- array(0, dim=dim(X))
  if (n > 0) {
    # Shift by one position in a single vectorised step: 2, 3, ..., n, 1.
    label[] <- X[c(seq_len(n)[-1], 1)]
  }
  return (label)
}
```
获取训练数据
```{r}
# Fetch the corpus (skipped when the file exists) and encode it.
download.data("./data/")
ret <- make.data("./data/input.txt", seq.len=seq.len)
X <- ret[["data"]]
dic <- ret[["dic"]]
lookup.table <- ret[["lookup.table"]]
vocab <- length(dic)

# Split the columns (sequences): the first 90% train, the rest validate.
shape <- dim(X)
size <- shape[2]
train.val.fraction <- 0.9

# Each part is trimmed to a whole number of batches right after slicing.
X.train.data <- drop.tail(X[, 1:as.integer(size * train.val.fraction)],
                          batch.size)
X.val.data <- drop.tail(X[, -(1:as.integer(size * train.val.fraction))],
                        batch.size)

# Labels are the inputs shifted by one character.
X.train.label <- get.label(X.train.data)
X.val.label <- get.label(X.val.data)

X.train <- list(data=X.train.data, label=X.train.label)
X.val <- list(data=X.val.data, label=X.val.label)
```
## 训练模型
在 `mxnet` 中,我们有一个 `mx.lstm` 函数,因此我们可以用这个函数来建立一个 `lstm` 模型。
```{r}
# Train the LSTM: X.train / X.val are the list(data=, label=) pairs built in
# the data chunk, and every remaining argument is one of the settings defined
# in the parameter chunk. The value returned by mx.lstm is kept in `model`
# (its `arg.params` field is read later when building the inference model).
model <- mx.lstm(X.train, X.val,
# Run on the CPU; the sample log shown below was produced with mx.gpu(0).
ctx=mx.cpu(),
num.round=num.round,
update.period=update.period,
num.lstm.layer=num.lstm.layer,
seq.len=seq.len,
num.hidden=num.hidden,
num.embed=num.embed,
# Output and input sizes are both the dictionary size, length(dic).
num.label=vocab,
batch.size=batch.size,
input.size=vocab,
initializer=mx.init.uniform(0.1),
learning.rate=learning.rate,
wd=wd,
clip_gradient=clip_gradient)
```
设置参数 `ctx=mx.gpu(0)` 和 `num.round=5` 可以得到下面的结果。
```
Epoch [31] Train: NLL=3.47213018872144, Perp=32.2052727363657
...
Epoch [961] Train: NLL=2.32060007657895, Perp=10.181782322355
Iter [1] Train: Time: 186.397065639496 sec, NLL=2.31135356537961, Perp=10.0880702804858
Iter [1] Val: NLL=1.94184484060012, Perp=6.97160060607419
Epoch [992] Train: NLL=1.84784553299322, Perp=6.34613225095329
...
Epoch [1953] Train: NLL=1.70175791172558, Perp=5.48357857093351
Iter [2] Train: Time: 188.929051160812 sec, NLL=1.70103940328978, Perp=5.47963998859367
Iter [2] Val: NLL=1.74979316010449, Perp=5.75341251767988
...
Epoch [2914] Train: NLL=1.54738185300295, Perp=4.69915099483974
Iter [3] Train: Time: 185.425321578979 sec, NLL=1.54604189517013, Perp=4.69285854740519
Iter [3] Val: NLL=1.67780240235925, Perp=5.35377758479576
Epoch [2945] Train: NLL=1.48868466087876, Perp=4.43126307034767
...
Iter [4] Train: Time: 185.487086296082 sec, NLL=1.4744973925858, Perp=4.36883940994296
Iter [4] Val: NLL=1.64488167325603, Perp=5.18039689118454
Epoch [3937] Train: NLL=1.46355541021581, Perp=4.32129622881604
...
Epoch [4898] Train: NLL=1.42900458455642, Perp=4.17454171976281
Iter [5] Train: Time: 185.070136785507 sec, NLL=1.42909226256273, Perp=4.17490775130428
Iter [5] Val: NLL=1.62716655804022, Perp=5.08943365437187
```
## 模型推理
随机采样的辅助函数。
```{r}
# Normalised cumulative sums of a weight vector.
#
# Args:
#   weights: numeric vector (or array, read in storage order) of weights.
#
# Returns:
#   Plain numeric vector of the same length whose k-th entry is
#   sum(weights[1:k]) / sum(weights); numeric(0) for empty input.
cdf <- function(weights) {
  # Flatten first so arrays behave like the element-wise traversal they
  # received before; base cumsum() replaces the O(n^2) append loop.
  w <- as.vector(weights)
  return (cumsum(w) / sum(w))
}
# Binary search over an ascending vector.
#
# Args:
#   cdf: numeric vector sorted in non-decreasing order.
#   x:   value to locate.
#
# Returns:
#   The smallest index whose entry is not below x; length(cdf) + 1 when
#   every entry is below x.
search.val <- function(cdf, x) {
  lo <- 1
  hi <- length(cdf)
  while (lo <= hi) {
    mid <- (lo + hi) %/% 2
    if (cdf[mid] >= x) {
      # Entry is large enough: the answer is here or further left.
      hi <- mid - 1
    } else {
      lo <- mid + 1
    }
  }
  return (lo)
}
# Draw one index at random, with probability proportional to its weight.
#
# Args:
#   weights: weights, converted with as.array() before use.
#
# Returns:
#   The index located by search.val() for a single runif(1) draw against the
#   cumulative distribution from cdf().
choice <- function(weights) {
  cumulative <- cdf(as.array(weights))
  draw <- runif(1)
  return (search.val(cumulative, draw))
}
```
我们可以选择按概率分布随机采样输出,或者固定选取概率最大的输出。
```{r}
# Turn a probability vector into a single 1-based index.
#
# Args:
#   prob:        probabilities; an mxnet NDArray is required when
#                sample = TRUE because the mx.nd.* operators are applied.
#   sample:      FALSE picks the most probable index; TRUE samples one.
#   temperature: divisor applied to the log-probabilities before sampling.
#
# Returns:
#   The selected index.
make.output <- function(prob, sample=FALSE, temperature=1.) {
  if (sample) {
    # Clip away exact 0 and 1 before taking the log, rescale by the
    # temperature, renormalise, then sample from the result.
    clipped <- mx.nd.clip(prob, 1e-6, 1 - 1e-6)
    tempered <- mx.nd.exp(mx.nd.log(clipped) / temperature)
    tempered <- tempered / (as.array(mx.nd.sum(tempered))[1])
    return (choice(tempered))
  }
  return (which.max(as.array(prob)))
}
```
在 `mxnet` 中,我们还有一个叫做 `mx.lstm.inference` 的函数,可以用训练好的 `lstm` 模型构建一个推理模型,然后使用 `mx.lstm.forward` 函数在该推理模型上做前向计算,得到预测结果。
```{r}
# Build the inference model with the same layer count and sizes that were
# passed to mx.lstm for training.
infer.model <- mx.lstm.inference(num.lstm.layer=num.lstm.layer,
input.size=vocab,
num.hidden=num.hidden,
num.embed=num.embed,
num.label=vocab,
# Takes the `arg.params` field of the object returned by mx.lstm
# (presumably the trained weights; confirm against the mxnet docs).
arg.params=model$arg.params,
ctx=mx.cpu())
```
使用函数 `mx.lstm.forward` 生成一个75个字符的序列
```
# Generate text one character at a time, starting from `start`.
start <- 'a'
# NOTE: this overwrites the training-time `seq.len`; here it is the total
# number of characters to produce, including `start`.
seq.len <- 75
random.sample <- TRUE
last.id <- dic[[start]]
# Seed the output with `start` itself so the two cannot get out of sync.
out <- start
for (i in seq_len(seq.len - 1)) {
  # The dictionary is 1-based while the ids fed to the network are 0-based.
  input <- c(last.id - 1)
  ret <- mx.lstm.forward(infer.model, input, FALSE)
  # Carry the model returned by mx.lstm.forward into the next step.
  infer.model <- ret$model
  prob <- ret$prob
  last.id <- make.output(prob, random.sample)
  out <- paste0(out, lookup.table[[last.id]])
}
cat (paste0(out, "\n"))
```
结果:
```
ah not a drobl greens
Settled asing lately sistering sounted to their hight
```