| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
# Sherlock Holmes language model using an LSTM. Replace mx.lstm with mx.gru or mx.rnn
# to train a GRU or vanilla RNN model instead.
| # The data file can be found at: |
| # https://github.com/dmlc/web-data/tree/master/mxnet/sherlockholmes |
require(hash)
require(mxnet)
require(stringr)
| |
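# load.data: read a whitespace-tokenized text file and map every token to an integer id.
# Lines are joined with an "<eos>" marker before splitting. If an existing dictionary is
# supplied it is extended in place, so the validation set reuses (and grows) the training
# vocabulary.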
| load.data <- function(path, dic=NULL) { |
| fi <- file(path, "r") |
| content <- paste(readLines(fi), collapse="<eos>") |
| close(fi) |
| #cat(content) |
| content <- str_split(content, ' ')[[1]] |
| cat(paste0("Loading ", path, ", size of data = ", length(content), "\n")) |
| X <- array(0, dim=c(length(content))) |
| #cat(X) |
| if (is.null(dic)) |
| dic <- hash() |
| idx <- 1 |
| for (i in 1:length(content)) { |
| word <- content[i] |
| if (str_length(word) > 0) { |
| if (!has.key(word, dic)) { |
| dic[[word]] <- idx |
| idx <- idx + 1 |
| } |
| X[i] <- dic[[word]] |
| } |
| } |
| cat(paste0("Unique token: ", length(dic), "\n")) |
| return (list(X=X, dic=dic)) |
| } |
| |
| |
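# replicate.data: truncate the token stream to a multiple of seq.len and reshape it into
# a (seq.len, num.seq) matrix, one sequence per column.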
replicate.data <- function(X, seq.len) {
    num.seq <- as.integer(length(X) / seq.len)
    X <- X[1:(num.seq * seq.len)]
    dim(X) <- c(seq.len, num.seq)
    return (X)
}
| |
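# drop.tail: drop trailing columns so the number of sequences is a multiple of batch.size.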
| drop.tail <- function(X, batch.size) { |
| shape <- dim(X) |
| nstep <- as.integer(shape[2] / batch.size) |
| return (X[, 1:(nstep * batch.size)]) |
| } |
| |
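# get.label: the label for each position is the next token in the stream (wrapping around
# at the very end), turning the data into a next-word-prediction problem.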
| get.label <- function(X) { |
| label <- array(0, dim=dim(X)) |
| d <- dim(X)[1] |
| w <- dim(X)[2] |
| for (i in 0:(w-1)) { |
| for (j in 1:d) { |
| label[i*d+j] <- X[(i*d+j)%%(w*d)+1] |
| } |
| } |
| return (label) |
| } |
| |
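# Hyper-parameters.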
batch.size <- 20
seq.len <- 35
num.hidden <- 200
num.embed <- 200
num.lstm.layer <- 2
num.round <- 15
learning.rate <- 0.1
wd <- 0.00001
update.period <- 1
| |
| |
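# The train/valid text files are assumed to have been downloaded from the URL above into
# ./data/. The validation set is loaded with the training dictionary so both splits share
# a single vocabulary.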
| train <- load.data("./data/sherlockholmes.train.txt") |
| X.train <- train$X |
| dic <- train$dic |
| val <- load.data("./data/sherlockholmes.valid.txt", dic) |
| X.val <- val$X |
| dic <- val$dic |
| X.train.data <- replicate.data(X.train, seq.len) |
| X.val.data <- replicate.data(X.val, seq.len) |
| vocab <- length(dic) |
| cat(paste0("Vocab=", vocab, "\n")) |
| |
| X.train.data <- drop.tail(X.train.data, batch.size) |
| X.val.data <- drop.tail(X.val.data, batch.size) |
| X.train.label <- get.label(X.train.data) |
| X.val.label <- get.label(X.val.data) |
| X.train <- list(data=X.train.data, label=X.train.label) |
| X.val <- list(data=X.val.data, label=X.val.label) |
| |
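# Train the LSTM language model. ctx=mx.gpu(0) assumes a GPU-enabled build of mxnet;
# use ctx=mx.cpu() to train on the CPU instead.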
| model <- mx.lstm(X.train, X.val, |
| ctx=mx.gpu(0), |
| num.round=num.round, |
| update.period=update.period, |
| num.lstm.layer=num.lstm.layer, |
| seq.len=seq.len, |
| num.hidden=num.hidden, |
| num.embed=num.embed, |
| num.label=vocab, |
| batch.size=batch.size, |
| input.size=vocab, |
| initializer=mx.init.uniform(0.01), |
| learning.rate=learning.rate, |
| wd=wd) |
| |